Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/db.py: 16%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

735 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import contextlib 

22import enum 

23import itertools 

24import json 

25import logging 

26import os 

27import sys 

28import time 

29import warnings 

30from dataclasses import dataclass 

31from tempfile import gettempdir 

32from typing import ( 

33 TYPE_CHECKING, 

34 Any, 

35 Callable, 

36 Generator, 

37 Iterable, 

38 Iterator, 

39 Protocol, 

40 Sequence, 

41 TypeVar, 

42 overload, 

43) 

44 

45import attrs 

46from sqlalchemy import ( 

47 Table, 

48 and_, 

49 column, 

50 delete, 

51 exc, 

52 func, 

53 inspect, 

54 literal, 

55 or_, 

56 select, 

57 table, 

58 text, 

59 tuple_, 

60) 

61 

62import airflow 

63from airflow import settings 

64from airflow.configuration import conf 

65from airflow.exceptions import AirflowException 

66from airflow.models import import_all_models 

67from airflow.utils import helpers 

68 

69# TODO: remove create_session once we decide to break backward compatibility 

70from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401 

71from airflow.utils.task_instance_session import get_current_task_instance_session 

72 

73if TYPE_CHECKING: 

74 from alembic.runtime.environment import EnvironmentContext 

75 from alembic.script import ScriptDirectory 

76 from sqlalchemy.engine import Row 

77 from sqlalchemy.orm import Query, Session 

78 from sqlalchemy.sql.elements import ClauseElement, TextClause 

79 from sqlalchemy.sql.selectable import Select 

80 

81 from airflow.models.connection import Connection 

82 from airflow.typing_compat import Self 

83 

84 # TODO: Import this from sqlalchemy.orm instead when switching to SQLA 2. 

85 # https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.MappedClassProtocol 

86 class MappedClassProtocol(Protocol): 

87 """Protocol for SQLALchemy model base.""" 

88 

89 __tablename__: str 

90 

91 

92T = TypeVar("T") 

93 

94log = logging.getLogger(__name__) 

95 

96_REVISION_HEADS_MAP = { 

97 "2.0.0": "e959f08ac86c", 

98 "2.0.1": "82b7c48c147f", 

99 "2.0.2": "2e42bb497a22", 

100 "2.1.0": "a13f7613ad25", 

101 "2.1.3": "97cdd93827b8", 

102 "2.1.4": "ccde3e26fe78", 

103 "2.2.0": "7b2661a43ba3", 

104 "2.2.3": "be2bfac3da23", 

105 "2.2.4": "587bdf053233", 

106 "2.3.0": "b1b348e02d07", 

107 "2.3.1": "1de7bc13c950", 

108 "2.3.2": "3c94c427fdf6", 

109 "2.3.3": "f5fcbda3e651", 

110 "2.4.0": "ecb43d2a1842", 

111 "2.4.2": "b0d31815b5a6", 

112 "2.4.3": "e07f49787c9d", 

113 "2.5.0": "290244fb8b83", 

114 "2.6.0": "98ae134e6fff", 

115 "2.6.2": "c804e5c76e3e", 

116 "2.7.0": "405de8318b3a", 

117 "2.8.0": "10b52ebd31f7", 

118 "2.8.1": "88344c1d9134", 

119 "2.9.0": "1949afb29106", 

120 "2.9.2": "0fd0c178cbe8", 

121 "2.10.0": "c4602ba06b4b", 

122} 

123 

124 

125def _format_airflow_moved_table_name(source_table, version, category): 

126 return "__".join([settings.AIRFLOW_MOVED_TABLE_PREFIX, version.replace(".", "_"), category, source_table]) 

127 

128 

129@provide_session 

130def merge_conn(conn: Connection, session: Session = NEW_SESSION): 

131 """Add new Connection.""" 

132 if not session.scalar(select(1).where(conn.__class__.conn_id == conn.conn_id)): 

133 session.add(conn) 

134 session.commit() 

135 

136 

137@provide_session 

138def add_default_pool_if_not_exists(session: Session = NEW_SESSION): 

139 """Add default pool if it does not exist.""" 

140 from airflow.models.pool import Pool 

141 

142 if not Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session): 

143 default_pool = Pool( 

144 pool=Pool.DEFAULT_POOL_NAME, 

145 slots=conf.getint(section="core", key="default_pool_task_slot_count"), 

146 description="Default pool", 

147 include_deferred=False, 

148 ) 

149 session.add(default_pool) 

150 session.commit() 

151 

152 

153@provide_session 

154def create_default_connections(session: Session = NEW_SESSION): 

155 """Create default Airflow connections.""" 

156 from airflow.models.connection import Connection 

157 

158 merge_conn( 

159 Connection( 

160 conn_id="airflow_db", 

161 conn_type="mysql", 

162 host="mysql", 

163 login="root", 

164 password="", 

165 schema="airflow", 

166 ), 

167 session, 

168 ) 

169 merge_conn( 

170 Connection( 

171 conn_id="athena_default", 

172 conn_type="athena", 

173 ), 

174 session, 

175 ) 

176 merge_conn( 

177 Connection( 

178 conn_id="aws_default", 

179 conn_type="aws", 

180 ), 

181 session, 

182 ) 

183 merge_conn( 

184 Connection( 

185 conn_id="azure_batch_default", 

186 conn_type="azure_batch", 

187 login="<ACCOUNT_NAME>", 

188 password="", 

189 extra="""{"account_url": "<ACCOUNT_URL>"}""", 

190 ) 

191 ) 

192 merge_conn( 

193 Connection( 

194 conn_id="azure_cosmos_default", 

195 conn_type="azure_cosmos", 

196 extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }', 

197 ), 

198 session, 

199 ) 

200 merge_conn( 

201 Connection( 

202 conn_id="azure_data_explorer_default", 

203 conn_type="azure_data_explorer", 

204 host="https://<CLUSTER>.kusto.windows.net", 

205 extra="""{"auth_method": "<AAD_APP | AAD_APP_CERT | AAD_CREDS | AAD_DEVICE>", 

206 "tenant": "<TENANT ID>", "certificate": "<APPLICATION PEM CERTIFICATE>", 

207 "thumbprint": "<APPLICATION CERTIFICATE THUMBPRINT>"}""", 

208 ), 

209 session, 

210 ) 

211 merge_conn( 

212 Connection( 

213 conn_id="azure_data_lake_default", 

214 conn_type="azure_data_lake", 

215 extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }', 

216 ), 

217 session, 

218 ) 

219 merge_conn( 

220 Connection( 

221 conn_id="azure_default", 

222 conn_type="azure", 

223 ), 

224 session, 

225 ) 

226 merge_conn( 

227 Connection( 

228 conn_id="cassandra_default", 

229 conn_type="cassandra", 

230 host="cassandra", 

231 port=9042, 

232 ), 

233 session, 

234 ) 

235 merge_conn( 

236 Connection( 

237 conn_id="databricks_default", 

238 conn_type="databricks", 

239 host="localhost", 

240 ), 

241 session, 

242 ) 

243 merge_conn( 

244 Connection( 

245 conn_id="dingding_default", 

246 conn_type="http", 

247 host="", 

248 password="", 

249 ), 

250 session, 

251 ) 

252 merge_conn( 

253 Connection( 

254 conn_id="drill_default", 

255 conn_type="drill", 

256 host="localhost", 

257 port=8047, 

258 extra='{"dialect_driver": "drill+sadrill", "storage_plugin": "dfs"}', 

259 ), 

260 session, 

261 ) 

262 merge_conn( 

263 Connection( 

264 conn_id="druid_broker_default", 

265 conn_type="druid", 

266 host="druid-broker", 

267 port=8082, 

268 extra='{"endpoint": "druid/v2/sql"}', 

269 ), 

270 session, 

271 ) 

272 merge_conn( 

273 Connection( 

274 conn_id="druid_ingest_default", 

275 conn_type="druid", 

276 host="druid-overlord", 

277 port=8081, 

278 extra='{"endpoint": "druid/indexer/v1/task"}', 

279 ), 

280 session, 

281 ) 

282 merge_conn( 

283 Connection( 

284 conn_id="elasticsearch_default", 

285 conn_type="elasticsearch", 

286 host="localhost", 

287 schema="http", 

288 port=9200, 

289 ), 

290 session, 

291 ) 

292 merge_conn( 

293 Connection( 

294 conn_id="emr_default", 

295 conn_type="emr", 

296 extra=""" 

297 { "Name": "default_job_flow_name", 

298 "LogUri": "s3://my-emr-log-bucket/default_job_flow_location", 

299 "ReleaseLabel": "emr-4.6.0", 

300 "Instances": { 

301 "Ec2KeyName": "mykey", 

302 "Ec2SubnetId": "somesubnet", 

303 "InstanceGroups": [ 

304 { 

305 "Name": "Master nodes", 

306 "Market": "ON_DEMAND", 

307 "InstanceRole": "MASTER", 

308 "InstanceType": "r3.2xlarge", 

309 "InstanceCount": 1 

310 }, 

311 { 

312 "Name": "Core nodes", 

313 "Market": "ON_DEMAND", 

314 "InstanceRole": "CORE", 

315 "InstanceType": "r3.2xlarge", 

316 "InstanceCount": 1 

317 } 

318 ], 

319 "TerminationProtected": false, 

320 "KeepJobFlowAliveWhenNoSteps": false 

321 }, 

322 "Applications":[ 

323 { "Name": "Spark" } 

324 ], 

325 "VisibleToAllUsers": true, 

326 "JobFlowRole": "EMR_EC2_DefaultRole", 

327 "ServiceRole": "EMR_DefaultRole", 

328 "Tags": [ 

329 { 

330 "Key": "app", 

331 "Value": "analytics" 

332 }, 

333 { 

334 "Key": "environment", 

335 "Value": "development" 

336 } 

337 ] 

338 } 

339 """, 

340 ), 

341 session, 

342 ) 

343 merge_conn( 

344 Connection( 

345 conn_id="facebook_default", 

346 conn_type="facebook_social", 

347 extra=""" 

348 { "account_id": "<AD_ACCOUNT_ID>", 

349 "app_id": "<FACEBOOK_APP_ID>", 

350 "app_secret": "<FACEBOOK_APP_SECRET>", 

351 "access_token": "<FACEBOOK_AD_ACCESS_TOKEN>" 

352 } 

353 """, 

354 ), 

355 session, 

356 ) 

357 merge_conn( 

358 Connection( 

359 conn_id="fs_default", 

360 conn_type="fs", 

361 extra='{"path": "/"}', 

362 ), 

363 session, 

364 ) 

365 merge_conn( 

366 Connection( 

367 conn_id="ftp_default", 

368 conn_type="ftp", 

369 host="localhost", 

370 port=21, 

371 login="airflow", 

372 password="airflow", 

373 extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', 

374 ), 

375 session, 

376 ) 

377 merge_conn( 

378 Connection( 

379 conn_id="google_cloud_default", 

380 conn_type="google_cloud_platform", 

381 schema="default", 

382 ), 

383 session, 

384 ) 

385 merge_conn( 

386 Connection( 

387 conn_id="hive_cli_default", 

388 conn_type="hive_cli", 

389 port=10000, 

390 host="localhost", 

391 extra='{"use_beeline": true, "auth": ""}', 

392 schema="default", 

393 ), 

394 session, 

395 ) 

396 merge_conn( 

397 Connection( 

398 conn_id="hiveserver2_default", 

399 conn_type="hiveserver2", 

400 host="localhost", 

401 schema="default", 

402 port=10000, 

403 ), 

404 session, 

405 ) 

406 merge_conn( 

407 Connection( 

408 conn_id="http_default", 

409 conn_type="http", 

410 host="https://www.httpbin.org/", 

411 ), 

412 session, 

413 ) 

414 merge_conn( 

415 Connection( 

416 conn_id="iceberg_default", 

417 conn_type="iceberg", 

418 host="https://api.iceberg.io/ws/v1", 

419 ), 

420 session, 

421 ) 

422 merge_conn(Connection(conn_id="impala_default", conn_type="impala", host="localhost", port=21050)) 

423 merge_conn( 

424 Connection( 

425 conn_id="kafka_default", 

426 conn_type="kafka", 

427 extra=json.dumps({"bootstrap.servers": "broker:29092", "group.id": "my-group"}), 

428 ), 

429 session, 

430 ) 

431 merge_conn( 

432 Connection( 

433 conn_id="kubernetes_default", 

434 conn_type="kubernetes", 

435 ), 

436 session, 

437 ) 

438 merge_conn( 

439 Connection( 

440 conn_id="kylin_default", 

441 conn_type="kylin", 

442 host="localhost", 

443 port=7070, 

444 login="ADMIN", 

445 password="KYLIN", 

446 ), 

447 session, 

448 ) 

449 merge_conn( 

450 Connection( 

451 conn_id="leveldb_default", 

452 conn_type="leveldb", 

453 host="localhost", 

454 ), 

455 session, 

456 ) 

457 merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session) 

458 merge_conn( 

459 Connection( 

460 conn_id="local_mysql", 

461 conn_type="mysql", 

462 host="localhost", 

463 login="airflow", 

464 password="airflow", 

465 schema="airflow", 

466 ), 

467 session, 

468 ) 

469 merge_conn( 

470 Connection( 

471 conn_id="metastore_default", 

472 conn_type="hive_metastore", 

473 host="localhost", 

474 extra='{"authMechanism": "PLAIN"}', 

475 port=9083, 

476 ), 

477 session, 

478 ) 

479 merge_conn(Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session) 

480 merge_conn( 

481 Connection( 

482 conn_id="mssql_default", 

483 conn_type="mssql", 

484 host="localhost", 

485 port=1433, 

486 ), 

487 session, 

488 ) 

489 merge_conn( 

490 Connection( 

491 conn_id="mysql_default", 

492 conn_type="mysql", 

493 login="root", 

494 schema="airflow", 

495 host="mysql", 

496 ), 

497 session, 

498 ) 

499 merge_conn( 

500 Connection( 

501 conn_id="opsgenie_default", 

502 conn_type="http", 

503 host="", 

504 password="", 

505 ), 

506 session, 

507 ) 

508 merge_conn( 

509 Connection( 

510 conn_id="oracle_default", 

511 conn_type="oracle", 

512 host="localhost", 

513 login="root", 

514 password="password", 

515 schema="schema", 

516 port=1521, 

517 ), 

518 session, 

519 ) 

520 merge_conn( 

521 Connection( 

522 conn_id="oss_default", 

523 conn_type="oss", 

524 extra="""{ 

525 "auth_type": "AK", 

526 "access_key_id": "<ACCESS_KEY_ID>", 

527 "access_key_secret": "<ACCESS_KEY_SECRET>", 

528 "region": "<YOUR_OSS_REGION>"} 

529 """, 

530 ), 

531 session, 

532 ) 

533 merge_conn( 

534 Connection( 

535 conn_id="pig_cli_default", 

536 conn_type="pig_cli", 

537 schema="default", 

538 ), 

539 session, 

540 ) 

541 merge_conn( 

542 Connection( 

543 conn_id="pinot_admin_default", 

544 conn_type="pinot", 

545 host="localhost", 

546 port=9000, 

547 ), 

548 session, 

549 ) 

550 merge_conn( 

551 Connection( 

552 conn_id="pinot_broker_default", 

553 conn_type="pinot", 

554 host="localhost", 

555 port=9000, 

556 extra='{"endpoint": "/query", "schema": "http"}', 

557 ), 

558 session, 

559 ) 

560 merge_conn( 

561 Connection( 

562 conn_id="postgres_default", 

563 conn_type="postgres", 

564 login="postgres", 

565 password="airflow", 

566 schema="airflow", 

567 host="postgres", 

568 ), 

569 session, 

570 ) 

571 merge_conn( 

572 Connection( 

573 conn_id="presto_default", 

574 conn_type="presto", 

575 host="localhost", 

576 schema="hive", 

577 port=3400, 

578 ), 

579 session, 

580 ) 

581 merge_conn( 

582 Connection( 

583 conn_id="qdrant_default", 

584 conn_type="qdrant", 

585 host="qdrant", 

586 port=6333, 

587 ), 

588 session, 

589 ) 

590 merge_conn( 

591 Connection( 

592 conn_id="redis_default", 

593 conn_type="redis", 

594 host="redis", 

595 port=6379, 

596 extra='{"db": 0}', 

597 ), 

598 session, 

599 ) 

600 merge_conn( 

601 Connection( 

602 conn_id="redshift_default", 

603 conn_type="redshift", 

604 extra="""{ 

605 "iam": true, 

606 "cluster_identifier": "<REDSHIFT_CLUSTER_IDENTIFIER>", 

607 "port": 5439, 

608 "profile": "default", 

609 "db_user": "awsuser", 

610 "database": "dev", 

611 "region": "" 

612}""", 

613 ), 

614 session, 

615 ) 

616 merge_conn( 

617 Connection( 

618 conn_id="salesforce_default", 

619 conn_type="salesforce", 

620 login="username", 

621 password="password", 

622 extra='{"security_token": "security_token"}', 

623 ), 

624 session, 

625 ) 

626 merge_conn( 

627 Connection( 

628 conn_id="segment_default", 

629 conn_type="segment", 

630 extra='{"write_key": "my-segment-write-key"}', 

631 ), 

632 session, 

633 ) 

634 merge_conn( 

635 Connection( 

636 conn_id="sftp_default", 

637 conn_type="sftp", 

638 host="localhost", 

639 port=22, 

640 login="airflow", 

641 extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', 

642 ), 

643 session, 

644 ) 

645 merge_conn( 

646 Connection( 

647 conn_id="spark_default", 

648 conn_type="spark", 

649 host="yarn", 

650 extra='{"queue": "root.default"}', 

651 ), 

652 session, 

653 ) 

654 merge_conn( 

655 Connection( 

656 conn_id="sqlite_default", 

657 conn_type="sqlite", 

658 host=os.path.join(gettempdir(), "sqlite_default.db"), 

659 ), 

660 session, 

661 ) 

662 merge_conn( 

663 Connection( 

664 conn_id="ssh_default", 

665 conn_type="ssh", 

666 host="localhost", 

667 ), 

668 session, 

669 ) 

670 merge_conn( 

671 Connection( 

672 conn_id="tableau_default", 

673 conn_type="tableau", 

674 host="https://tableau.server.url", 

675 login="user", 

676 password="password", 

677 extra='{"site_id": "my_site"}', 

678 ), 

679 session, 

680 ) 

681 merge_conn( 

682 Connection( 

683 conn_id="tabular_default", 

684 conn_type="tabular", 

685 host="https://api.tabulardata.io/ws/v1", 

686 ), 

687 session, 

688 ) 

689 merge_conn( 

690 Connection( 

691 conn_id="teradata_default", 

692 conn_type="teradata", 

693 host="localhost", 

694 login="user", 

695 password="password", 

696 schema="schema", 

697 ), 

698 session, 

699 ) 

700 merge_conn( 

701 Connection( 

702 conn_id="trino_default", 

703 conn_type="trino", 

704 host="localhost", 

705 schema="hive", 

706 port=3400, 

707 ), 

708 session, 

709 ) 

710 merge_conn( 

711 Connection( 

712 conn_id="vertica_default", 

713 conn_type="vertica", 

714 host="localhost", 

715 port=5433, 

716 ), 

717 session, 

718 ) 

719 merge_conn( 

720 Connection( 

721 conn_id="wasb_default", 

722 conn_type="wasb", 

723 extra='{"sas_token": null}', 

724 ), 

725 session, 

726 ) 

727 merge_conn( 

728 Connection( 

729 conn_id="webhdfs_default", 

730 conn_type="hdfs", 

731 host="localhost", 

732 port=50070, 

733 ), 

734 session, 

735 ) 

736 merge_conn( 

737 Connection( 

738 conn_id="yandexcloud_default", 

739 conn_type="yandexcloud", 

740 schema="default", 

741 ), 

742 session, 

743 ) 

744 

745 

746def _get_flask_db(sql_database_uri): 

747 from flask import Flask 

748 from flask_sqlalchemy import SQLAlchemy 

749 

750 from airflow.www.session import AirflowDatabaseSessionInterface 

751 

752 flask_app = Flask(__name__) 

753 flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri 

754 flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False 

755 db = SQLAlchemy(flask_app) 

756 AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="") 

757 return db 

758 

759 

760def _create_db_from_orm(session): 

761 from alembic import command 

762 

763 from airflow.models.base import Base 

764 from airflow.providers.fab.auth_manager.models import Model 

765 

766 def _create_flask_session_tbl(sql_database_uri): 

767 db = _get_flask_db(sql_database_uri) 

768 db.create_all() 

769 

770 with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): 

771 engine = session.get_bind().engine 

772 Base.metadata.create_all(engine) 

773 Model.metadata.create_all(engine) 

774 _create_flask_session_tbl(engine.url) 

775 # stamp the migration head 

776 config = _get_alembic_config() 

777 command.stamp(config, "head") 

778 

779 

780@provide_session 

781def initdb(session: Session = NEW_SESSION, load_connections: bool = True, use_migration_files: bool = False): 

782 """Initialize Airflow database.""" 

783 import_all_models() 

784 

785 db_exists = _get_current_revision(session) 

786 if db_exists or use_migration_files: 

787 upgradedb(session=session, use_migration_files=use_migration_files) 

788 else: 

789 _create_db_from_orm(session=session) 

790 if conf.getboolean("database", "LOAD_DEFAULT_CONNECTIONS") and load_connections: 

791 create_default_connections(session=session) 

792 # Add default pool & sync log_template 

793 add_default_pool_if_not_exists(session=session) 

794 synchronize_log_template(session=session) 

795 

796 

797def _get_alembic_config(): 

798 from alembic.config import Config 

799 

800 package_dir = os.path.dirname(airflow.__file__) 

801 directory = os.path.join(package_dir, "migrations") 

802 alembic_file = conf.get("database", "alembic_ini_file_path") 

803 if os.path.isabs(alembic_file): 

804 config = Config(alembic_file) 

805 else: 

806 config = Config(os.path.join(package_dir, alembic_file)) 

807 config.set_main_option("script_location", directory.replace("%", "%%")) 

808 config.set_main_option("sqlalchemy.url", settings.SQL_ALCHEMY_CONN.replace("%", "%%")) 

809 return config 

810 

811 

812def _get_script_object(config=None) -> ScriptDirectory: 

813 from alembic.script import ScriptDirectory 

814 

815 if not config: 

816 config = _get_alembic_config() 

817 return ScriptDirectory.from_config(config) 

818 

819 

820def _get_current_revision(session): 

821 from alembic.migration import MigrationContext 

822 

823 conn = session.connection() 

824 

825 migration_ctx = MigrationContext.configure(conn) 

826 

827 return migration_ctx.get_current_revision() 

828 

829 

830def check_migrations(timeout): 

831 """ 

832 Wait for all airflow migrations to complete. 

833 

834 :param timeout: Timeout for the migration in seconds 

835 :return: None 

836 """ 

837 timeout = timeout or 1 # run the loop at least 1 

838 with _configured_alembic_environment() as env: 

839 context = env.get_context() 

840 source_heads = None 

841 db_heads = None 

842 for ticker in range(timeout): 

843 source_heads = set(env.script.get_heads()) 

844 db_heads = set(context.get_current_heads()) 

845 if source_heads == db_heads: 

846 return 

847 time.sleep(1) 

848 log.info("Waiting for migrations... %s second(s)", ticker) 

849 raise TimeoutError( 

850 f"There are still unapplied migrations after {timeout} seconds. Migration" 

851 f"Head(s) in DB: {db_heads} | Migration Head(s) in Source Code: {source_heads}" 

852 ) 

853 

854 

855@contextlib.contextmanager 

856def _configured_alembic_environment() -> Generator[EnvironmentContext, None, None]: 

857 from alembic.runtime.environment import EnvironmentContext 

858 

859 config = _get_alembic_config() 

860 script = _get_script_object(config) 

861 

862 with EnvironmentContext( 

863 config, 

864 script, 

865 ) as env, settings.engine.connect() as connection: 

866 alembic_logger = logging.getLogger("alembic") 

867 level = alembic_logger.level 

868 alembic_logger.setLevel(logging.WARNING) 

869 env.configure(connection) 

870 alembic_logger.setLevel(level) 

871 

872 yield env 

873 

874 

875def check_and_run_migrations(): 

876 """Check and run migrations if necessary. Only use in a tty.""" 

877 with _configured_alembic_environment() as env: 

878 context = env.get_context() 

879 source_heads = set(env.script.get_heads()) 

880 db_heads = set(context.get_current_heads()) 

881 db_command = None 

882 command_name = None 

883 verb = None 

884 if len(db_heads) < 1: 

885 db_command = initdb 

886 command_name = "init" 

887 verb = "initialize" 

888 elif source_heads != db_heads: 

889 db_command = upgradedb 

890 command_name = "upgrade" 

891 verb = "upgrade" 

892 

893 if sys.stdout.isatty() and verb: 

894 print() 

895 question = f"Please confirm database {verb} (or wait 4 seconds to skip it). Are you sure? [y/N]" 

896 try: 

897 answer = helpers.prompt_with_timeout(question, timeout=4, default=False) 

898 if answer: 

899 try: 

900 db_command() 

901 print(f"DB {verb} done") 

902 except Exception as error: 

903 from airflow.version import version 

904 

905 print(error) 

906 print( 

907 "You still have unapplied migrations. " 

908 f"You may need to {verb} the database by running `airflow db {command_name}`. ", 

909 f"Make sure the command is run using Airflow version {version}.", 

910 file=sys.stderr, 

911 ) 

912 sys.exit(1) 

913 except AirflowException: 

914 pass 

915 elif source_heads != db_heads: 

916 from airflow.version import version 

917 

918 print( 

919 f"ERROR: You need to {verb} the database. Please run `airflow db {command_name}`. " 

920 f"Make sure the command is run using Airflow version {version}.", 

921 file=sys.stderr, 

922 ) 

923 sys.exit(1) 

924 

925 

926def _reserialize_dags(*, session: Session) -> None: 

927 from airflow.models.dagbag import DagBag 

928 from airflow.models.serialized_dag import SerializedDagModel 

929 

930 session.execute(delete(SerializedDagModel).execution_options(synchronize_session=False)) 

931 dagbag = DagBag(collect_dags=False) 

932 dagbag.collect_dags(only_if_updated=False) 

933 dagbag.sync_to_db(session=session) 

934 

935 

936@provide_session 

937def synchronize_log_template(*, session: Session = NEW_SESSION) -> None: 

938 """Synchronize log template configs with table. 

939 

940 This checks if the last row fully matches the current config values, and 

941 insert a new row if not. 

942 """ 

943 # NOTE: SELECT queries in this function are INTENTIONALLY written with the 

944 # SQL builder style, not the ORM query API. This avoids configuring the ORM 

945 # unless we need to insert something, speeding up CLI in general. 

946 

947 from airflow.models.tasklog import LogTemplate 

948 

949 metadata = reflect_tables([LogTemplate], session) 

950 log_template_table: Table | None = metadata.tables.get(LogTemplate.__tablename__) 

951 

952 if log_template_table is None: 

953 log.info("Log template table does not exist (added in 2.3.0); skipping log template sync.") 

954 return 

955 

956 filename = conf.get("logging", "log_filename_template") 

957 elasticsearch_id = conf.get("elasticsearch", "log_id_template") 

958 

959 stored = session.execute( 

960 select( 

961 log_template_table.c.filename, 

962 log_template_table.c.elasticsearch_id, 

963 ) 

964 .order_by(log_template_table.c.id.desc()) 

965 .limit(1) 

966 ).first() 

967 

968 # If we have an empty table, and the default values exist, we will seed the 

969 # table with values from pre 2.3.0, so old logs will still be retrievable. 

970 if not stored: 

971 is_default_log_id = elasticsearch_id == conf.get_default_value("elasticsearch", "log_id_template") 

972 is_default_filename = filename == conf.get_default_value("logging", "log_filename_template") 

973 if is_default_log_id and is_default_filename: 

974 session.add( 

975 LogTemplate( 

976 filename="{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log", 

977 elasticsearch_id="{dag_id}-{task_id}-{execution_date}-{try_number}", 

978 ) 

979 ) 

980 

981 # Before checking if the _current_ value exists, we need to check if the old config value we upgraded in 

982 # place exists! 

983 pre_upgrade_filename = conf.upgraded_values.get(("logging", "log_filename_template"), filename) 

984 pre_upgrade_elasticsearch_id = conf.upgraded_values.get( 

985 ("elasticsearch", "log_id_template"), elasticsearch_id 

986 ) 

987 if pre_upgrade_filename != filename or pre_upgrade_elasticsearch_id != elasticsearch_id: 

988 # The previous non-upgraded value likely won't be the _latest_ value (as after we've recorded the 

989 # recorded the upgraded value it will be second-to-newest), so we'll have to just search which is okay 

990 # as this is a table with a tiny number of rows 

991 row = session.execute( 

992 select(log_template_table.c.id) 

993 .where( 

994 or_( 

995 log_template_table.c.filename == pre_upgrade_filename, 

996 log_template_table.c.elasticsearch_id == pre_upgrade_elasticsearch_id, 

997 ) 

998 ) 

999 .order_by(log_template_table.c.id.desc()) 

1000 .limit(1) 

1001 ).first() 

1002 if not row: 

1003 session.add( 

1004 LogTemplate(filename=pre_upgrade_filename, elasticsearch_id=pre_upgrade_elasticsearch_id) 

1005 ) 

1006 

1007 if not stored or stored.filename != filename or stored.elasticsearch_id != elasticsearch_id: 

1008 session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id)) 

1009 

1010 

1011def check_conn_id_duplicates(session: Session) -> Iterable[str]: 

1012 """ 

1013 Check unique conn_id in connection table. 

1014 

1015 :param session: session of the sqlalchemy 

1016 """ 

1017 from airflow.models.connection import Connection 

1018 

1019 try: 

1020 dups = session.scalars( 

1021 select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1) 

1022 ).all() 

1023 except (exc.OperationalError, exc.ProgrammingError): 

1024 # fallback if tables hasn't been created yet 

1025 session.rollback() 

1026 return 

1027 if dups: 

1028 yield ( 

1029 "Seems you have non unique conn_id in connection table.\n" 

1030 "You have to manage those duplicate connections " 

1031 "before upgrading the database.\n" 

1032 f"Duplicated conn_id: {dups}" 

1033 ) 

1034 

1035 

1036def check_username_duplicates(session: Session) -> Iterable[str]: 

1037 """ 

1038 Check unique username in User & RegisterUser table. 

1039 

1040 :param session: session of the sqlalchemy 

1041 :rtype: str 

1042 """ 

1043 from airflow.providers.fab.auth_manager.models import RegisterUser, User 

1044 

1045 for model in [User, RegisterUser]: 

1046 dups = [] 

1047 try: 

1048 dups = session.execute( 

1049 select(model.username) # type: ignore[attr-defined] 

1050 .group_by(model.username) # type: ignore[attr-defined] 

1051 .having(func.count() > 1) 

1052 ).all() 

1053 except (exc.OperationalError, exc.ProgrammingError): 

1054 # fallback if tables hasn't been created yet 

1055 session.rollback() 

1056 if dups: 

1057 yield ( 

1058 f"Seems you have mixed case usernames in {model.__table__.name} table.\n" # type: ignore 

1059 "You have to rename or delete those mixed case usernames " 

1060 "before upgrading the database.\n" 

1061 f"usernames with mixed cases: {[dup.username for dup in dups]}" 

1062 ) 

1063 

1064 

1065def reflect_tables(tables: list[MappedClassProtocol | str] | None, session): 

1066 """ 

1067 When running checks prior to upgrades, we use reflection to determine current state of the database. 

1068 

1069 This function gets the current state of each table in the set of models 

1070 provided and returns a SqlAlchemy metadata object containing them. 

1071 """ 

1072 import sqlalchemy.schema 

1073 

1074 bind = session.bind 

1075 metadata = sqlalchemy.schema.MetaData() 

1076 

1077 if tables is None: 

1078 metadata.reflect(bind=bind, resolve_fks=False) 

1079 else: 

1080 for tbl in tables: 

1081 try: 

1082 table_name = tbl if isinstance(tbl, str) else tbl.__tablename__ 

1083 metadata.reflect(bind=bind, only=[table_name], extend_existing=True, resolve_fks=False) 

1084 except exc.InvalidRequestError: 

1085 continue 

1086 return metadata 

1087 

1088 

1089def check_table_for_duplicates( 

1090 *, session: Session, table_name: str, uniqueness: list[str], version: str 

1091) -> Iterable[str]: 

1092 """ 

1093 Check table for duplicates, given a list of columns which define the uniqueness of the table. 

1094 

1095 Usage example: 

1096 

1097 .. code-block:: python 

1098 

1099 def check_task_fail_for_duplicates(session): 

1100 from airflow.models.taskfail import TaskFail 

1101 

1102 metadata = reflect_tables([TaskFail], session) 

1103 task_fail = metadata.tables.get(TaskFail.__tablename__) # type: ignore 

1104 if task_fail is None: # table not there 

1105 return 

1106 if "run_id" in task_fail.columns: # upgrade already applied 

1107 return 

1108 yield from check_table_for_duplicates( 

1109 table_name=task_fail.name, 

1110 uniqueness=["dag_id", "task_id", "execution_date"], 

1111 session=session, 

1112 version="2.3", 

1113 ) 

1114 

1115 :param table_name: table name to check 

1116 :param uniqueness: uniqueness constraint to evaluate against 

1117 :param session: session of the sqlalchemy 

1118 """ 

1119 minimal_table_obj = table(table_name, *(column(x) for x in uniqueness)) 

1120 try: 

1121 subquery = session.execute( 

1122 select(minimal_table_obj, func.count().label("dupe_count")) 

1123 .group_by(*(text(x) for x in uniqueness)) 

1124 .having(func.count() > text("1")) 

1125 .subquery() 

1126 ) 

1127 dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count))) 

1128 if not dupe_count: 

1129 # there are no duplicates; nothing to do. 

1130 return 

1131 

1132 log.warning("Found %s duplicates in table %s. Will attempt to move them.", dupe_count, table_name) 

1133 

1134 metadata = reflect_tables(tables=[table_name], session=session) 

1135 if table_name not in metadata.tables: 

1136 yield f"Table {table_name} does not exist in the database." 

1137 

1138 # We can't use the model here since it may differ from the db state due to 

1139 # this function is run prior to migration. Use the reflected table instead. 

1140 table_obj = metadata.tables[table_name] 

1141 

1142 _move_duplicate_data_to_new_table( 

1143 session=session, 

1144 source_table=table_obj, 

1145 subquery=subquery, 

1146 uniqueness=uniqueness, 

1147 target_table_name=_format_airflow_moved_table_name(table_name, version, "duplicates"), 

1148 ) 

1149 except (exc.OperationalError, exc.ProgrammingError): 

1150 # fallback if `table_name` hasn't been created yet 

1151 session.rollback() 

1152 

1153 

1154def check_conn_type_null(session: Session) -> Iterable[str]: 

1155 """ 

1156 Check nullable conn_type column in Connection table. 

1157 

1158 :param session: session of the sqlalchemy 

1159 """ 

1160 from airflow.models.connection import Connection 

1161 

1162 try: 

1163 n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all() 

1164 except (exc.OperationalError, exc.ProgrammingError, exc.InternalError): 

1165 # fallback if tables hasn't been created yet 

1166 session.rollback() 

1167 return 

1168 

1169 if n_nulls: 

1170 yield ( 

1171 "The conn_type column in the connection " 

1172 "table must contain content.\n" 

1173 "Make sure you don't have null " 

1174 "in the conn_type column.\n" 

1175 f"Null conn_type conn_id: {n_nulls}" 

1176 ) 

1177 

1178 

1179def _format_dangling_error(source_table, target_table, invalid_count, reason): 

1180 noun = "row" if invalid_count == 1 else "rows" 

1181 return ( 

1182 f"The {source_table} table has {invalid_count} {noun} {reason}, which " 

1183 f"is invalid. We could not move them out of the way because the " 

1184 f"{target_table} table already exists in your database. Please either " 

1185 f"drop the {target_table} table, or manually delete the invalid rows " 

1186 f"from the {source_table} table." 

1187 ) 

1188 

1189 

1190def check_run_id_null(session: Session) -> Iterable[str]: 

1191 from airflow.models.dagrun import DagRun 

1192 

1193 metadata = reflect_tables([DagRun], session) 

1194 

1195 # We can't use the model here since it may differ from the db state due to 

1196 # this function is run prior to migration. Use the reflected table instead. 

1197 dagrun_table = metadata.tables.get(DagRun.__tablename__) 

1198 if dagrun_table is None: 

1199 return 

1200 

1201 invalid_dagrun_filter = or_( 

1202 dagrun_table.c.dag_id.is_(None), 

1203 dagrun_table.c.run_id.is_(None), 

1204 dagrun_table.c.execution_date.is_(None), 

1205 ) 

1206 invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter)) 

1207 if invalid_dagrun_count > 0: 

1208 dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling") 

1209 if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names(): 

1210 yield _format_dangling_error( 

1211 source_table=dagrun_table.name, 

1212 target_table=dagrun_dangling_table_name, 

1213 invalid_count=invalid_dagrun_count, 

1214 reason="with a NULL dag_id, run_id, or execution_date", 

1215 ) 

1216 return 

1217 

1218 bind = session.get_bind() 

1219 dialect_name = bind.dialect.name 

1220 _create_table_as( 

1221 dialect_name=dialect_name, 

1222 source_query=dagrun_table.select(invalid_dagrun_filter), 

1223 target_table_name=dagrun_dangling_table_name, 

1224 source_table_name=dagrun_table.name, 

1225 session=session, 

1226 ) 

1227 delete = dagrun_table.delete().where(invalid_dagrun_filter) 

1228 session.execute(delete) 

1229 

1230 

1231def _create_table_as( 

1232 *, 

1233 session, 

1234 dialect_name: str, 

1235 source_query: Query, 

1236 target_table_name: str, 

1237 source_table_name: str, 

1238): 

1239 """ 

1240 Create a new table with rows from query. 

1241 

1242 We have to handle CTAS differently for different dialects. 

1243 """ 

1244 if dialect_name == "mysql": 

1245 # MySQL with replication needs this split in to two queries, so just do it for all MySQL 

1246 # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT. 

1247 session.execute(text(f"CREATE TABLE {target_table_name} LIKE {source_table_name}")) 

1248 session.execute( 

1249 text( 

1250 f"INSERT INTO {target_table_name} {source_query.selectable.compile(bind=session.get_bind())}" 

1251 ) 

1252 ) 

1253 else: 

1254 # Postgres and SQLite both support the same "CREATE TABLE a AS SELECT ..." syntax 

1255 select_table = source_query.selectable.compile(bind=session.get_bind()) 

1256 session.execute(text(f"CREATE TABLE {target_table_name} AS {select_table}")) 

1257 

1258 

1259def _move_dangling_data_to_new_table( 

1260 session, source_table: Table, source_query: Query, target_table_name: str 

1261): 

1262 bind = session.get_bind() 

1263 dialect_name = bind.dialect.name 

1264 

1265 # First: Create moved rows from new table 

1266 log.debug("running CTAS for table %s", target_table_name) 

1267 _create_table_as( 

1268 dialect_name=dialect_name, 

1269 source_query=source_query, 

1270 target_table_name=target_table_name, 

1271 source_table_name=source_table.name, 

1272 session=session, 

1273 ) 

1274 session.commit() 

1275 

1276 target_table = source_table.to_metadata(source_table.metadata, name=target_table_name) 

1277 log.debug("checking whether rows were moved for table %s", target_table_name) 

1278 moved_rows_exist_query = select(1).select_from(target_table).limit(1) 

1279 first_moved_row = session.execute(moved_rows_exist_query).all() 

1280 session.commit() 

1281 

1282 if not first_moved_row: 

1283 log.debug("no rows moved; dropping %s", target_table_name) 

1284 # no bad rows were found; drop moved rows table. 

1285 target_table.drop(bind=session.get_bind(), checkfirst=True) 

1286 else: 

1287 log.debug("rows moved; purging from %s", source_table.name) 

1288 if dialect_name == "sqlite": 

1289 pk_cols = source_table.primary_key.columns 

1290 

1291 delete = source_table.delete().where( 

1292 tuple_(*pk_cols).in_(session.select(*target_table.primary_key.columns).subquery()) 

1293 ) 

1294 else: 

1295 delete = source_table.delete().where( 

1296 and_(col == target_table.c[col.name] for col in source_table.primary_key.columns) 

1297 ) 

1298 log.debug(delete.compile()) 

1299 session.execute(delete) 

1300 session.commit() 

1301 

1302 log.debug("exiting move function") 

1303 

1304 

1305def _dangling_against_dag_run(session, source_table, dag_run): 

1306 """Given a source table, we generate a subquery that will return 1 for every row that has a dagrun.""" 

1307 source_to_dag_run_join_cond = and_( 

1308 source_table.c.dag_id == dag_run.c.dag_id, 

1309 source_table.c.execution_date == dag_run.c.execution_date, 

1310 ) 

1311 

1312 return ( 

1313 select(*(c.label(c.name) for c in source_table.c)) 

1314 .join(dag_run, source_to_dag_run_join_cond, isouter=True) 

1315 .where(dag_run.c.dag_id.is_(None)) 

1316 ) 

1317 

1318 

1319def _dangling_against_task_instance(session, source_table, dag_run, task_instance): 

1320 """ 

1321 Given a source table, generate a subquery that will return 1 for every row that has a valid task instance. 

1322 

1323 This is used to identify rows that need to be removed from tables prior to adding a TI fk. 

1324 

1325 Since this check is applied prior to running the migrations, we have to use different 

1326 query logic depending on which revision the database is at. 

1327 

1328 """ 

1329 if "run_id" not in task_instance.c: 

1330 # db is < 2.2.0 

1331 dr_join_cond = and_( 

1332 source_table.c.dag_id == dag_run.c.dag_id, 

1333 source_table.c.execution_date == dag_run.c.execution_date, 

1334 ) 

1335 ti_join_cond = and_( 

1336 dag_run.c.dag_id == task_instance.c.dag_id, 

1337 dag_run.c.execution_date == task_instance.c.execution_date, 

1338 source_table.c.task_id == task_instance.c.task_id, 

1339 ) 

1340 else: 

1341 # db is 2.2.0 <= version < 2.3.0 

1342 dr_join_cond = and_( 

1343 source_table.c.dag_id == dag_run.c.dag_id, 

1344 source_table.c.execution_date == dag_run.c.execution_date, 

1345 ) 

1346 ti_join_cond = and_( 

1347 dag_run.c.dag_id == task_instance.c.dag_id, 

1348 dag_run.c.run_id == task_instance.c.run_id, 

1349 source_table.c.task_id == task_instance.c.task_id, 

1350 ) 

1351 

1352 return ( 

1353 select(*(c.label(c.name) for c in source_table.c)) 

1354 .outerjoin(dag_run, dr_join_cond) 

1355 .outerjoin(task_instance, ti_join_cond) 

1356 .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))) 

1357 ) 

1358 

1359 

1360def _move_duplicate_data_to_new_table( 

1361 session, source_table: Table, subquery: Query, uniqueness: list[str], target_table_name: str 

1362): 

1363 """ 

1364 When adding a uniqueness constraint we first should ensure that there are no duplicate rows. 

1365 

1366 This function accepts a subquery that should return one record for each row with duplicates (e.g. 

1367 a group by with having count(*) > 1). We select from ``source_table`` getting all rows matching the 

1368 subquery result and store in ``target_table_name``. Then to purge the duplicates from the source table, 

1369 we do a DELETE FROM with a join to the target table (which now contains the dupes). 

1370 

1371 :param session: sqlalchemy session for metadata db 

1372 :param source_table: table to purge dupes from 

1373 :param subquery: the subquery that returns the duplicate rows 

1374 :param uniqueness: the string list of columns used to define the uniqueness for the table. used in 

1375 building the DELETE FROM join condition. 

1376 :param target_table_name: name of the table in which to park the duplicate rows 

1377 """ 

1378 bind = session.get_bind() 

1379 dialect_name = bind.dialect.name 

1380 

1381 query = ( 

1382 select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns)) 

1383 .select_from(source_table) 

1384 .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness))) 

1385 ) 

1386 

1387 _create_table_as( 

1388 session=session, 

1389 dialect_name=dialect_name, 

1390 source_query=query, 

1391 target_table_name=target_table_name, 

1392 source_table_name=source_table.name, 

1393 ) 

1394 

1395 # we must ensure that the CTAS table is created prior to the DELETE step since we have to join to it 

1396 session.commit() 

1397 

1398 metadata = reflect_tables([target_table_name], session) 

1399 target_table = metadata.tables[target_table_name] 

1400 where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness)) 

1401 

1402 if dialect_name == "sqlite": 

1403 subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")]) 

1404 delete = source_table.delete().where(column("ROWID").in_(subq)) 

1405 else: 

1406 delete = source_table.delete(where_clause) 

1407 

1408 session.execute(delete) 

1409 

1410 

1411def check_bad_references(session: Session) -> Iterable[str]: 

1412 """ 

1413 Go through each table and look for records that can't be mapped to a dag run. 

1414 

1415 When we find such "dangling" rows we back them up in a special table and delete them 

1416 from the main table. 

1417 

1418 Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id` in many tables. 

1419 """ 

1420 from airflow.models.dagrun import DagRun 

1421 from airflow.models.renderedtifields import RenderedTaskInstanceFields 

1422 from airflow.models.taskfail import TaskFail 

1423 from airflow.models.taskinstance import TaskInstance 

1424 from airflow.models.taskreschedule import TaskReschedule 

1425 from airflow.models.xcom import XCom 

1426 

1427 @dataclass 

1428 class BadReferenceConfig: 

1429 """ 

1430 Bad reference config class. 

1431 

1432 :param bad_rows_func: function that returns subquery which determines whether bad rows exist 

1433 :param join_tables: table objects referenced in subquery 

1434 :param ref_table: information-only identifier for categorizing the missing ref 

1435 """ 

1436 

1437 bad_rows_func: Callable 

1438 join_tables: list[str] 

1439 ref_table: str 

1440 

1441 missing_dag_run_config = BadReferenceConfig( 

1442 bad_rows_func=_dangling_against_dag_run, 

1443 join_tables=["dag_run"], 

1444 ref_table="dag_run", 

1445 ) 

1446 

1447 missing_ti_config = BadReferenceConfig( 

1448 bad_rows_func=_dangling_against_task_instance, 

1449 join_tables=["dag_run", "task_instance"], 

1450 ref_table="task_instance", 

1451 ) 

1452 

1453 models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [ 

1454 (TaskInstance, "2.2", missing_dag_run_config), 

1455 (TaskReschedule, "2.2", missing_ti_config), 

1456 (RenderedTaskInstanceFields, "2.3", missing_ti_config), 

1457 (TaskFail, "2.3", missing_ti_config), 

1458 (XCom, "2.3", missing_ti_config), 

1459 ] 

1460 metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session) 

1461 

1462 if ( 

1463 not metadata.tables 

1464 or metadata.tables.get(DagRun.__tablename__) is None 

1465 or metadata.tables.get(TaskInstance.__tablename__) is None 

1466 ): 

1467 # Key table doesn't exist -- likely empty DB. 

1468 return 

1469 

1470 existing_table_names = set(inspect(session.get_bind()).get_table_names()) 

1471 errored = False 

1472 

1473 for model, change_version, bad_ref_cfg in models_list: 

1474 log.debug("checking model %s", model.__tablename__) 

1475 # We can't use the model here since it may differ from the db state due to 

1476 # this function is run prior to migration. Use the reflected table instead. 

1477 source_table = metadata.tables.get(model.__tablename__) # type: ignore 

1478 if source_table is None: 

1479 continue 

1480 

1481 # Migration already applied, don't check again. 

1482 if "run_id" in source_table.columns: 

1483 continue 

1484 

1485 func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables} 

1486 bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs) 

1487 

1488 dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, "dangling") 

1489 if dangling_table_name in existing_table_names: 

1490 invalid_row_count = get_query_count(bad_rows_query, session=session) 

1491 if invalid_row_count: 

1492 yield _format_dangling_error( 

1493 source_table=source_table.name, 

1494 target_table=dangling_table_name, 

1495 invalid_count=invalid_row_count, 

1496 reason=f"without a corresponding {bad_ref_cfg.ref_table} row", 

1497 ) 

1498 errored = True 

1499 continue 

1500 

1501 log.debug("moving data for table %s", source_table.name) 

1502 _move_dangling_data_to_new_table( 

1503 session, 

1504 source_table, 

1505 bad_rows_query, 

1506 dangling_table_name, 

1507 ) 

1508 

1509 if errored: 

1510 session.rollback() 

1511 else: 

1512 session.commit() 

1513 

1514 

1515@provide_session 

1516def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]: 

1517 """:session: session of the sqlalchemy.""" 

1518 check_functions: tuple[Callable[..., Iterable[str]], ...] = ( 

1519 check_conn_id_duplicates, 

1520 check_conn_type_null, 

1521 check_run_id_null, 

1522 check_bad_references, 

1523 check_username_duplicates, 

1524 ) 

1525 for check_fn in check_functions: 

1526 log.debug("running check function %s", check_fn.__name__) 

1527 yield from check_fn(session=session) 

1528 

1529 

1530def _offline_migration(migration_func: Callable, config, revision): 

1531 with warnings.catch_warnings(): 

1532 warnings.simplefilter("ignore") 

1533 logging.disable(logging.CRITICAL) 

1534 migration_func(config, revision, sql=True) 

1535 logging.disable(logging.NOTSET) 

1536 

1537 

1538def print_happy_cat(message): 

1539 if sys.stdout.isatty(): 

1540 size = os.get_terminal_size().columns 

1541 else: 

1542 size = 0 

1543 print(message.center(size)) 

1544 print("""/\\_/\\""".center(size)) 

1545 print("""(='_' )""".center(size)) 

1546 print("""(,(") (")""".center(size)) 

1547 print("""^^^""".center(size)) 

1548 return 

1549 

1550 

1551def _revision_greater(config, this_rev, base_rev): 

1552 # Check if there is history between the revisions and the start revision 

1553 # This ensures that the revisions are above `min_revision` 

1554 script = _get_script_object(config) 

1555 try: 

1556 list(script.revision_map.iterate_revisions(upper=this_rev, lower=base_rev)) 

1557 return True 

1558 except Exception: 

1559 return False 

1560 

1561 

1562def _revisions_above_min_for_offline(config, revisions) -> None: 

1563 """ 

1564 Check that all supplied revision ids are above the minimum revision for the dialect. 

1565 

1566 :param config: Alembic config 

1567 :param revisions: list of Alembic revision ids 

1568 :return: None 

1569 """ 

1570 dbname = settings.engine.dialect.name 

1571 if dbname == "sqlite": 

1572 raise SystemExit("Offline migration not supported for SQLite.") 

1573 min_version, min_revision = ("2.2.0", "7b2661a43ba3") if dbname == "mssql" else ("2.0.0", "e959f08ac86c") 

1574 

1575 # Check if there is history between the revisions and the start revision 

1576 # This ensures that the revisions are above `min_revision` 

1577 for rev in revisions: 

1578 if not _revision_greater(config, rev, min_revision): 

1579 raise ValueError( 

1580 f"Error while checking history for revision range {min_revision}:{rev}. " 

1581 f"Check that {rev} is a valid revision. " 

1582 f"For dialect {dbname!r}, supported revision for offline migration is from {min_revision} " 

1583 f"which corresponds to Airflow {min_version}." 

1584 ) 

1585 

1586 

1587@provide_session 

1588def upgradedb( 

1589 *, 

1590 to_revision: str | None = None, 

1591 from_revision: str | None = None, 

1592 show_sql_only: bool = False, 

1593 reserialize_dags: bool = True, 

1594 session: Session = NEW_SESSION, 

1595 use_migration_files: bool = False, 

1596): 

1597 """ 

1598 Upgrades the DB. 

1599 

1600 :param to_revision: Optional Alembic revision ID to upgrade *to*. 

1601 If omitted, upgrades to latest revision. 

1602 :param from_revision: Optional Alembic revision ID to upgrade *from*. 

1603 Not compatible with ``sql_only=False``. 

1604 :param show_sql_only: if True, migration statements will be printed but not executed. 

1605 :param session: sqlalchemy session with connection to Airflow metadata database 

1606 :return: None 

1607 """ 

1608 if from_revision and not show_sql_only: 

1609 raise AirflowException("`from_revision` only supported with `sql_only=True`.") 

1610 

1611 # alembic adds significant import time, so we import it lazily 

1612 if not settings.SQL_ALCHEMY_CONN: 

1613 raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set. This is a critical assertion.") 

1614 from alembic import command 

1615 

1616 import_all_models() 

1617 

1618 config = _get_alembic_config() 

1619 

1620 if show_sql_only: 

1621 if not from_revision: 

1622 from_revision = _get_current_revision(session) 

1623 

1624 if not to_revision: 

1625 script = _get_script_object() 

1626 to_revision = script.get_current_head() 

1627 

1628 if to_revision == from_revision: 

1629 print_happy_cat("No migrations to apply; nothing to do.") 

1630 return 

1631 

1632 if not _revision_greater(config, to_revision, from_revision): 

1633 raise ValueError( 

1634 f"Requested *to* revision {to_revision} is older than *from* revision {from_revision}. " 

1635 "Please check your requested versions / revisions." 

1636 ) 

1637 _revisions_above_min_for_offline(config=config, revisions=[from_revision, to_revision]) 

1638 

1639 _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}") 

1640 return # only running sql; our job is done 

1641 

1642 errors_seen = False 

1643 for err in _check_migration_errors(session=session): 

1644 if not errors_seen: 

1645 log.error("Automatic migration is not available") 

1646 errors_seen = True 

1647 log.error("%s", err) 

1648 

1649 if errors_seen: 

1650 exit(1) 

1651 

1652 if not to_revision and not _get_current_revision(session=session) and not use_migration_files: 

1653 # Don't load default connections 

1654 # New DB; initialize and exit 

1655 initdb(session=session, load_connections=False) 

1656 return 

1657 with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): 

1658 import sqlalchemy.pool 

1659 

1660 log.info("Creating tables") 

1661 val = os.environ.get("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE") 

1662 try: 

1663 # Reconfigure the ORM to use _EXACTLY_ one connection, otherwise some db engines hang forever 

1664 # trying to ALTER TABLEs 

1665 os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = "1" 

1666 settings.reconfigure_orm(pool_class=sqlalchemy.pool.SingletonThreadPool) 

1667 command.upgrade(config, revision=to_revision or "heads") 

1668 finally: 

1669 if val is None: 

1670 os.environ.pop("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE") 

1671 else: 

1672 os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = val 

1673 settings.reconfigure_orm() 

1674 

1675 if reserialize_dags: 

1676 _reserialize_dags(session=session) 

1677 add_default_pool_if_not_exists(session=session) 

1678 synchronize_log_template(session=session) 

1679 

1680 

1681@provide_session 

1682def resetdb(session: Session = NEW_SESSION, skip_init: bool = False, use_migration_files: bool = False): 

1683 """Clear out the database.""" 

1684 if not settings.engine: 

1685 raise RuntimeError("The settings.engine must be set. This is a critical assertion") 

1686 log.info("Dropping tables that exist") 

1687 

1688 import_all_models() 

1689 

1690 connection = settings.engine.connect() 

1691 

1692 with create_global_lock(session=session, lock=DBLocks.MIGRATIONS), connection.begin(): 

1693 drop_airflow_models(connection) 

1694 drop_airflow_moved_tables(connection) 

1695 

1696 if not skip_init: 

1697 initdb(session=session, use_migration_files=use_migration_files) 

1698 

1699 

1700@provide_session 

1701def bootstrap_dagbag(session: Session = NEW_SESSION): 

1702 from airflow.models.dag import DAG 

1703 from airflow.models.dagbag import DagBag 

1704 

1705 dagbag = DagBag() 

1706 # Save DAGs in the ORM 

1707 dagbag.sync_to_db(session=session) 

1708 

1709 # Deactivate the unknown ones 

1710 DAG.deactivate_unknown_dags(dagbag.dags.keys(), session=session) 

1711 

1712 

1713@provide_session 

1714def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session: Session = NEW_SESSION): 

1715 """ 

1716 Downgrade the airflow metastore schema to a prior version. 

1717 

1718 :param to_revision: The alembic revision to downgrade *to*. 

1719 :param show_sql_only: if True, print sql statements but do not run them 

1720 :param from_revision: if supplied, alembic revision to dawngrade *from*. This may only 

1721 be used in conjunction with ``sql=True`` because if we actually run the commands, 

1722 we should only downgrade from the *current* revision. 

1723 :param session: sqlalchemy session for connection to airflow metadata database 

1724 """ 

1725 if from_revision and not show_sql_only: 

1726 raise ValueError( 

1727 "`from_revision` can't be combined with `sql=False`. When actually " 

1728 "applying a downgrade (instead of just generating sql), we always " 

1729 "downgrade from current revision." 

1730 ) 

1731 

1732 if not settings.SQL_ALCHEMY_CONN: 

1733 raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.") 

1734 

1735 # alembic adds significant import time, so we import it lazily 

1736 from alembic import command 

1737 

1738 log.info("Attempting downgrade to revision %s", to_revision) 

1739 config = _get_alembic_config() 

1740 

1741 with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): 

1742 if show_sql_only: 

1743 log.warning("Generating sql scripts for manual migration.") 

1744 if not from_revision: 

1745 from_revision = _get_current_revision(session) 

1746 revision_range = f"{from_revision}:{to_revision}" 

1747 _offline_migration(command.downgrade, config=config, revision=revision_range) 

1748 else: 

1749 log.info("Applying downgrade migrations.") 

1750 command.downgrade(config, revision=to_revision, sql=show_sql_only) 

1751 

1752 

1753def drop_airflow_models(connection): 

1754 """ 

1755 Drop all airflow models. 

1756 

1757 :param connection: SQLAlchemy Connection 

1758 :return: None 

1759 """ 

1760 from airflow.models.base import Base 

1761 from airflow.providers.fab.auth_manager.models import Model 

1762 

1763 Base.metadata.drop_all(connection) 

1764 Model.metadata.drop_all(connection) 

1765 db = _get_flask_db(connection.engine.url) 

1766 db.drop_all() 

1767 # alembic adds significant import time, so we import it lazily 

1768 from alembic.migration import MigrationContext 

1769 

1770 migration_ctx = MigrationContext.configure(connection) 

1771 version = migration_ctx._version 

1772 if inspect(connection).has_table(version.name): 

1773 version.drop(connection) 

1774 

1775 

1776def drop_airflow_moved_tables(connection): 

1777 from airflow.models.base import Base 

1778 from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX 

1779 

1780 tables = set(inspect(connection).get_table_names()) 

1781 to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)] 

1782 for tbl in to_delete: 

1783 tbl.drop(settings.engine, checkfirst=False) 

1784 Base.metadata.remove(tbl) 

1785 

1786 

1787@provide_session 

1788def check(session: Session = NEW_SESSION): 

1789 """ 

1790 Check if the database works. 

1791 

1792 :param session: session of the sqlalchemy 

1793 """ 

1794 session.execute(text("select 1 as is_alive;")) 

1795 log.info("Connection successful.") 

1796 

1797 

1798@enum.unique 

1799class DBLocks(enum.IntEnum): 

1800 """ 

1801 Cross-db Identifiers for advisory global database locks. 

1802 

1803 Postgres uses int64 lock ids so we use the integer value, MySQL uses names, so we 

1804 call ``str()`, which is implemented using the ``_name_`` field. 

1805 """ 

1806 

1807 MIGRATIONS = enum.auto() 

1808 SCHEDULER_CRITICAL_SECTION = enum.auto() 

1809 

1810 def __str__(self): 

1811 return f"airflow_{self._name_}" 

1812 

1813 

1814@contextlib.contextmanager 

1815def create_global_lock( 

1816 session: Session, 

1817 lock: DBLocks, 

1818 lock_timeout: int = 1800, 

1819) -> Generator[None, None, None]: 

1820 """Contextmanager that will create and teardown a global db lock.""" 

1821 conn = session.get_bind().connect() 

1822 dialect = conn.dialect 

1823 try: 

1824 if dialect.name == "postgresql": 

1825 conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout}) 

1826 conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value}) 

1827 elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6): 

1828 conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout}) 

1829 

1830 yield 

1831 finally: 

1832 if dialect.name == "postgresql": 

1833 conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT")) 

1834 (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone() 

1835 if not unlocked: 

1836 raise RuntimeError("Error releasing DB lock!") 

1837 elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6): 

1838 conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)}) 

1839 

1840 

1841def compare_type(context, inspected_column, metadata_column, inspected_type, metadata_type): 

1842 """ 

1843 Compare types between ORM and DB . 

1844 

1845 return False if the metadata_type is the same as the inspected_type 

1846 or None to allow the default implementation to compare these 

1847 types. a return value of True means the two types do not 

1848 match and should result in a type change operation. 

1849 """ 

1850 if context.dialect.name == "mysql": 

1851 from sqlalchemy import String 

1852 from sqlalchemy.dialects import mysql 

1853 

1854 if isinstance(inspected_type, mysql.VARCHAR) and isinstance(metadata_type, String): 

1855 # This is a hack to get around MySQL VARCHAR collation 

1856 # not being possible to change from utf8_bin to utf8mb3_bin. 

1857 # We only make sure lengths are the same 

1858 if inspected_type.length != metadata_type.length: 

1859 return True 

1860 return False 

1861 return None 

1862 

1863 

1864def compare_server_default( 

1865 context, inspected_column, metadata_column, inspected_default, metadata_default, rendered_metadata_default 

1866): 

1867 """ 

1868 Compare server defaults between ORM and DB . 

1869 

1870 return True if the defaults are different, False if not, or None to allow the default implementation 

1871 to compare these defaults 

1872 

1873 In SQLite: task_instance.map_index & task_reschedule.map_index 

1874 are not comparing accurately. Sometimes they are equal, sometimes they are not. 

1875 Alembic warned that this feature has varied accuracy depending on backends. 

1876 See: (https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime. 

1877 environment.EnvironmentContext.configure.params.compare_server_default) 

1878 """ 

1879 dialect_name = context.connection.dialect.name 

1880 if dialect_name in ["sqlite"]: 

1881 return False 

1882 if ( 

1883 dialect_name == "mysql" 

1884 and metadata_column.name == "pool_slots" 

1885 and metadata_column.table.name == "task_instance" 

1886 ): 

1887 # We removed server_default value in ORM to avoid expensive migration 

1888 # (it was removed in postgres DB in migration head 7b2661a43ba3 ). 

1889 # As a side note, server default value here was only actually needed for the migration 

1890 # where we added the column in the first place -- now that it exists and all 

1891 # existing rows are populated with a value this server default is never used. 

1892 return False 

1893 return None 

1894 

1895 

1896def get_sqla_model_classes(): 

1897 """ 

1898 Get all SQLAlchemy class mappers. 

1899 

1900 SQLAlchemy < 1.4 does not support registry.mappers so we use 

1901 try/except to handle it. 

1902 """ 

1903 from airflow.models.base import Base 

1904 

1905 try: 

1906 return [mapper.class_ for mapper in Base.registry.mappers] 

1907 except AttributeError: 

1908 return Base._decl_class_registry.values() 

1909 

1910 

1911def get_query_count(query_stmt: Select, *, session: Session) -> int: 

1912 """Get count of a query. 

1913 

1914 A SELECT COUNT() FROM is issued against the subquery built from the 

1915 given statement. The ORDER BY clause is stripped from the statement 

1916 since it's unnecessary for COUNT, and can impact query planning and 

1917 degrade performance. 

1918 

1919 :meta private: 

1920 """ 

1921 count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery()) 

1922 return session.scalar(count_stmt) 

1923 

1924 

1925def check_query_exists(query_stmt: Select, *, session: Session) -> bool: 

1926 """Check whether there is at least one row matching a query. 

1927 

1928 A SELECT 1 FROM is issued against the subquery built from the given 

1929 statement. The ORDER BY clause is stripped from the statement since it's 

1930 unnecessary, and can impact query planning and degrade performance. 

1931 

1932 :meta private: 

1933 """ 

1934 count_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery()) 

1935 return session.scalar(count_stmt) 

1936 

1937 

1938def exists_query(*where: ClauseElement, session: Session) -> bool: 

1939 """Check whether there is at least one row matching given clauses. 

1940 

1941 This does a SELECT 1 WHERE ... LIMIT 1 and check the result. 

1942 

1943 :meta private: 

1944 """ 

1945 stmt = select(literal(True)).where(*where).limit(1) 

1946 return session.scalar(stmt) is not None 

1947 

1948 

1949@attrs.define(slots=True) 

1950class LazySelectSequence(Sequence[T]): 

1951 """List-like interface to lazily access a database model query. 

1952 

1953 The intended use case is inside a task execution context, where we manage an 

1954 active SQLAlchemy session in the background. 

1955 

1956 This is an abstract base class. Each use case should subclass, and implement 

1957 the following static methods: 

1958 

1959 * ``_rebuild_select`` is called when a lazy sequence is unpickled. Since it 

1960 is not easy to pickle SQLAlchemy constructs, this class serializes the 

1961 SELECT statements into plain text to storage. This method is called on 

1962 deserialization to convert the textual clause back into an ORM SELECT. 

1963 * ``_process_row`` is called when an item is accessed. The lazy sequence 

1964 uses ``session.execute()`` to fetch rows from the database, and this 

1965 method should know how to process each row into a value. 

1966 

1967 :meta private: 

1968 """ 

1969 

1970 _select_asc: ClauseElement 

1971 _select_desc: ClauseElement 

1972 _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session) 

1973 _len: int | None = attrs.field(init=False, default=None) 

1974 

1975 @classmethod 

1976 def from_select( 

1977 cls, 

1978 select: Select, 

1979 *, 

1980 order_by: Sequence[ClauseElement], 

1981 session: Session | None = None, 

1982 ) -> Self: 

1983 s1 = select 

1984 for col in order_by: 

1985 s1 = s1.order_by(col.asc()) 

1986 s2 = select 

1987 for col in order_by: 

1988 s2 = s2.order_by(col.desc()) 

1989 return cls(s1, s2, session=session or get_current_task_instance_session()) 

1990 

1991 @staticmethod 

1992 def _rebuild_select(stmt: TextClause) -> Select: 

1993 """Rebuild a textual statement into an ORM-configured SELECT statement. 

1994 

1995 This should do something like ``select(field).from_statement(stmt)`` to 

1996 reconfigure ORM information to the textual SQL statement. 

1997 """ 

1998 raise NotImplementedError 

1999 

2000 @staticmethod 

2001 def _process_row(row: Row) -> T: 

2002 """Process a SELECT-ed row into the end value.""" 

2003 raise NotImplementedError 

2004 

2005 def __repr__(self) -> str: 

2006 counter = "item" if (length := len(self)) == 1 else "items" 

2007 return f"LazySelectSequence([{length} {counter}])" 

2008 

2009 def __str__(self) -> str: 

2010 counter = "item" if (length := len(self)) == 1 else "items" 

2011 return f"LazySelectSequence([{length} {counter}])" 

2012 

2013 def __getstate__(self) -> Any: 

2014 # We don't want to go to the trouble of serializing SQLAlchemy objects. 

2015 # Converting the statement into a SQL string is the best we can get. 

2016 # The literal_binds compile argument inlines all the values into the SQL 

2017 # string to simplify cross-process commuinication as much as possible. 

2018 # Theoratically we can do the same for count(), but I think it should be 

2019 # performant enough to calculate only that eagerly. 

2020 s1 = str(self._select_asc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True})) 

2021 s2 = str(self._select_desc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True})) 

2022 return (s1, s2, len(self)) 

2023 

2024 def __setstate__(self, state: Any) -> None: 

2025 s1, s2, self._len = state 

2026 self._select_asc = self._rebuild_select(text(s1)) 

2027 self._select_desc = self._rebuild_select(text(s2)) 

2028 self._session = get_current_task_instance_session() 

2029 

2030 def __bool__(self) -> bool: 

2031 return check_query_exists(self._select_asc, session=self._session) 

2032 

2033 def __eq__(self, other: Any) -> bool: 

2034 if not isinstance(other, collections.abc.Sequence): 

2035 return NotImplemented 

2036 z = itertools.zip_longest(iter(self), iter(other), fillvalue=object()) 

2037 return all(x == y for x, y in z) 

2038 

2039 def __reversed__(self) -> Iterator[T]: 

2040 return iter(self._process_row(r) for r in self._session.execute(self._select_desc)) 

2041 

2042 def __iter__(self) -> Iterator[T]: 

2043 return iter(self._process_row(r) for r in self._session.execute(self._select_asc)) 

2044 

2045 def __len__(self) -> int: 

2046 if self._len is None: 

2047 self._len = get_query_count(self._select_asc, session=self._session) 

2048 return self._len 

2049 

2050 @overload 

2051 def __getitem__(self, key: int) -> T: ... 

2052 

2053 @overload 

2054 def __getitem__(self, key: slice) -> Sequence[T]: ... 

2055 

2056 def __getitem__(self, key: int | slice) -> T | Sequence[T]: 

2057 if isinstance(key, int): 

2058 if key >= 0: 

2059 stmt = self._select_asc.offset(key) 

2060 else: 

2061 stmt = self._select_desc.offset(-1 - key) 

2062 if (row := self._session.execute(stmt.limit(1)).one_or_none()) is None: 

2063 raise IndexError(key) 

2064 return self._process_row(row) 

2065 elif isinstance(key, slice): 

2066 # This implements the slicing syntax. We want to optimize negative 

2067 # slicing (e.g. seq[-10:]) by not doing an additional COUNT query 

2068 # if possible. We can do this unless the start and stop have 

2069 # different signs (i.e. one is positive and another negative). 

2070 start, stop, reverse = _coerce_slice(key) 

2071 if start >= 0: 

2072 if stop is None: 

2073 stmt = self._select_asc.offset(start) 

2074 elif stop >= 0: 

2075 stmt = self._select_asc.slice(start, stop) 

2076 else: 

2077 stmt = self._select_asc.slice(start, len(self) + stop) 

2078 rows = [self._process_row(row) for row in self._session.execute(stmt)] 

2079 if reverse: 

2080 rows.reverse() 

2081 else: 

2082 if stop is None: 

2083 stmt = self._select_desc.limit(-start) 

2084 elif stop < 0: 

2085 stmt = self._select_desc.slice(-stop, -start) 

2086 else: 

2087 stmt = self._select_desc.slice(len(self) - stop, -start) 

2088 rows = [self._process_row(row) for row in self._session.execute(stmt)] 

2089 if not reverse: 

2090 rows.reverse() 

2091 return rows 

2092 raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}") 

2093 

2094 

2095def _coerce_index(value: Any) -> int | None: 

2096 """Check slice attribute's type and convert it to int. 

2097 

2098 See CPython documentation on this: 

2099 https://docs.python.org/3/reference/datamodel.html#object.__index__ 

2100 """ 

2101 if value is None or isinstance(value, int): 

2102 return value 

2103 if (index := getattr(value, "__index__", None)) is not None: 

2104 return index() 

2105 raise TypeError("slice indices must be integers or None or have an __index__ method") 

2106 

2107 

2108def _coerce_slice(key: slice) -> tuple[int, int | None, bool]: 

2109 """Check slice content and convert it for SQL. 

2110 

2111 See CPython documentation on this: 

2112 https://docs.python.org/3/reference/datamodel.html#slice-objects 

2113 """ 

2114 if key.step is None or key.step == 1: 

2115 reverse = False 

2116 elif key.step == -1: 

2117 reverse = True 

2118 else: 

2119 raise ValueError("non-trivial slice step not supported") 

2120 return _coerce_index(key.start) or 0, _coerce_index(key.stop), reverse