Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/db.py: 16%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections.abc
21import contextlib
22import enum
23import itertools
24import json
25import logging
26import os
27import sys
28import time
29import warnings
30from dataclasses import dataclass
31from tempfile import gettempdir
32from typing import (
33 TYPE_CHECKING,
34 Any,
35 Callable,
36 Generator,
37 Iterable,
38 Iterator,
39 Protocol,
40 Sequence,
41 TypeVar,
42 overload,
43)
45import attrs
46from sqlalchemy import (
47 Table,
48 and_,
49 column,
50 delete,
51 exc,
52 func,
53 inspect,
54 literal,
55 or_,
56 select,
57 table,
58 text,
59 tuple_,
60)
62import airflow
63from airflow import settings
64from airflow.configuration import conf
65from airflow.exceptions import AirflowException
66from airflow.models import import_all_models
67from airflow.utils import helpers
69# TODO: remove create_session once we decide to break backward compatibility
70from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401
71from airflow.utils.task_instance_session import get_current_task_instance_session
if TYPE_CHECKING:
    from alembic.runtime.environment import EnvironmentContext
    from alembic.script import ScriptDirectory
    from sqlalchemy.engine import Row
    from sqlalchemy.orm import Query, Session
    from sqlalchemy.sql.elements import ClauseElement, TextClause
    from sqlalchemy.sql.selectable import Select

    from airflow.models.connection import Connection
    from airflow.typing_compat import Self

    # TODO: Import this from sqlalchemy.orm instead when switching to SQLA 2.
    # https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.MappedClassProtocol
    class MappedClassProtocol(Protocol):
        """Structural type for SQLAlchemy mapped classes: anything exposing ``__tablename__``."""

        __tablename__: str


T = TypeVar("T")

log = logging.getLogger(__name__)

# Airflow version -> alembic revision id, listed in ascending version order.
# NOTE(review): presumably the migration head shipped with each release; confirm
# against the migrations directory when adding a new entry.
_REVISION_HEADS_MAP: dict[str, str] = {
    # 2.0.x
    "2.0.0": "e959f08ac86c",
    "2.0.1": "82b7c48c147f",
    "2.0.2": "2e42bb497a22",
    # 2.1.x
    "2.1.0": "a13f7613ad25",
    "2.1.3": "97cdd93827b8",
    "2.1.4": "ccde3e26fe78",
    # 2.2.x
    "2.2.0": "7b2661a43ba3",
    "2.2.3": "be2bfac3da23",
    "2.2.4": "587bdf053233",
    # 2.3.x
    "2.3.0": "b1b348e02d07",
    "2.3.1": "1de7bc13c950",
    "2.3.2": "3c94c427fdf6",
    "2.3.3": "f5fcbda3e651",
    # 2.4.x
    "2.4.0": "ecb43d2a1842",
    "2.4.2": "b0d31815b5a6",
    "2.4.3": "e07f49787c9d",
    # 2.5.x
    "2.5.0": "290244fb8b83",
    # 2.6.x
    "2.6.0": "98ae134e6fff",
    "2.6.2": "c804e5c76e3e",
    # 2.7.x
    "2.7.0": "405de8318b3a",
    # 2.8.x
    "2.8.0": "10b52ebd31f7",
    "2.8.1": "88344c1d9134",
    # 2.9.x
    "2.9.0": "1949afb29106",
    "2.9.2": "0fd0c178cbe8",
    # 2.10.x
    "2.10.0": "c4602ba06b4b",
}
def _format_airflow_moved_table_name(source_table, version, category):
    """
    Build the name of the table that rows moved out of ``source_table`` are stored in.

    The result is ``<AIRFLOW_MOVED_TABLE_PREFIX>__<version, dots -> underscores>__<category>__<source_table>``.
    """
    version_part = version.replace(".", "_")
    name_parts = (settings.AIRFLOW_MOVED_TABLE_PREFIX, version_part, category, source_table)
    return "__".join(name_parts)
@provide_session
def merge_conn(conn: Connection, session: Session = NEW_SESSION):
    """
    Add new Connection, unless one with the same ``conn_id`` is already stored.

    :param conn: connection object to persist.
    :param session: SQLAlchemy ORM session; a successful insert is committed immediately.
    """
    model = conn.__class__
    already_stored = session.scalar(select(1).where(model.conn_id == conn.conn_id))
    if already_stored:
        return
    session.add(conn)
    session.commit()
@provide_session
def add_default_pool_if_not_exists(session: Session = NEW_SESSION):
    """
    Add default pool if it does not exist.

    The pool's slot count comes from ``[core] default_pool_task_slot_count``; the
    new pool is committed immediately.

    :param session: SQLAlchemy ORM session used for the lookup and insert.
    """
    from airflow.models.pool import Pool

    if Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session):
        return

    slot_count = conf.getint(section="core", key="default_pool_task_slot_count")
    session.add(
        Pool(
            pool=Pool.DEFAULT_POOL_NAME,
            slots=slot_count,
            description="Default pool",
            include_deferred=False,
        )
    )
    session.commit()
@provide_session
def create_default_connections(session: Session = NEW_SESSION):
    """
    Create default Airflow connections.

    Every connection below is handed to :func:`merge_conn`. That function only
    inserts a connection when none with the same ``conn_id`` exists yet, and
    commits after each insert, so existing connections are never overwritten.

    :param session: SQLAlchemy ORM session used for every lookup and insert.
    """
    from airflow.models.connection import Connection

    default_connections = [
        Connection(
            conn_id="airflow_db",
            conn_type="mysql",
            host="mysql",
            login="root",
            password="",
            schema="airflow",
        ),
        Connection(conn_id="athena_default", conn_type="athena"),
        Connection(conn_id="aws_default", conn_type="aws"),
        Connection(
            conn_id="azure_batch_default",
            conn_type="azure_batch",
            login="<ACCOUNT_NAME>",
            password="",
            extra="""{"account_url": "<ACCOUNT_URL>"}""",
        ),
        Connection(
            conn_id="azure_cosmos_default",
            conn_type="azure_cosmos",
            extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }',
        ),
        Connection(
            conn_id="azure_data_explorer_default",
            conn_type="azure_data_explorer",
            host="https://<CLUSTER>.kusto.windows.net",
            extra="""{"auth_method": "<AAD_APP | AAD_APP_CERT | AAD_CREDS | AAD_DEVICE>",
                    "tenant": "<TENANT ID>", "certificate": "<APPLICATION PEM CERTIFICATE>",
                    "thumbprint": "<APPLICATION CERTIFICATE THUMBPRINT>"}""",
        ),
        Connection(
            conn_id="azure_data_lake_default",
            conn_type="azure_data_lake",
            extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }',
        ),
        Connection(conn_id="azure_default", conn_type="azure"),
        Connection(conn_id="cassandra_default", conn_type="cassandra", host="cassandra", port=9042),
        Connection(conn_id="databricks_default", conn_type="databricks", host="localhost"),
        Connection(conn_id="dingding_default", conn_type="http", host="", password=""),
        Connection(
            conn_id="drill_default",
            conn_type="drill",
            host="localhost",
            port=8047,
            extra='{"dialect_driver": "drill+sadrill", "storage_plugin": "dfs"}',
        ),
        Connection(
            conn_id="druid_broker_default",
            conn_type="druid",
            host="druid-broker",
            port=8082,
            extra='{"endpoint": "druid/v2/sql"}',
        ),
        Connection(
            conn_id="druid_ingest_default",
            conn_type="druid",
            host="druid-overlord",
            port=8081,
            extra='{"endpoint": "druid/indexer/v1/task"}',
        ),
        Connection(
            conn_id="elasticsearch_default",
            conn_type="elasticsearch",
            host="localhost",
            schema="http",
            port=9200,
        ),
        Connection(
            conn_id="emr_default",
            conn_type="emr",
            extra="""
                {   "Name": "default_job_flow_name",
                    "LogUri": "s3://my-emr-log-bucket/default_job_flow_location",
                    "ReleaseLabel": "emr-4.6.0",
                    "Instances": {
                        "Ec2KeyName": "mykey",
                        "Ec2SubnetId": "somesubnet",
                        "InstanceGroups": [
                            {
                                "Name": "Master nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "MASTER",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            },
                            {
                                "Name": "Core nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "CORE",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            }
                        ],
                        "TerminationProtected": false,
                        "KeepJobFlowAliveWhenNoSteps": false
                    },
                    "Applications":[
                        { "Name": "Spark" }
                    ],
                    "VisibleToAllUsers": true,
                    "JobFlowRole": "EMR_EC2_DefaultRole",
                    "ServiceRole": "EMR_DefaultRole",
                    "Tags": [
                        {
                            "Key": "app",
                            "Value": "analytics"
                        },
                        {
                            "Key": "environment",
                            "Value": "development"
                        }
                    ]
                }
            """,
        ),
        Connection(
            conn_id="facebook_default",
            conn_type="facebook_social",
            extra="""
                {   "account_id": "<AD_ACCOUNT_ID>",
                    "app_id": "<FACEBOOK_APP_ID>",
                    "app_secret": "<FACEBOOK_APP_SECRET>",
                    "access_token": "<FACEBOOK_AD_ACCESS_TOKEN>"
                }
            """,
        ),
        Connection(conn_id="fs_default", conn_type="fs", extra='{"path": "/"}'),
        Connection(
            conn_id="ftp_default",
            conn_type="ftp",
            host="localhost",
            port=21,
            login="airflow",
            password="airflow",
            extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}',
        ),
        Connection(conn_id="google_cloud_default", conn_type="google_cloud_platform", schema="default"),
        Connection(
            conn_id="hive_cli_default",
            conn_type="hive_cli",
            port=10000,
            host="localhost",
            extra='{"use_beeline": true, "auth": ""}',
            schema="default",
        ),
        Connection(
            conn_id="hiveserver2_default",
            conn_type="hiveserver2",
            host="localhost",
            schema="default",
            port=10000,
        ),
        Connection(conn_id="http_default", conn_type="http", host="https://www.httpbin.org/"),
        Connection(conn_id="iceberg_default", conn_type="iceberg", host="https://api.iceberg.io/ws/v1"),
        Connection(conn_id="impala_default", conn_type="impala", host="localhost", port=21050),
        Connection(
            conn_id="kafka_default",
            conn_type="kafka",
            extra=json.dumps({"bootstrap.servers": "broker:29092", "group.id": "my-group"}),
        ),
        Connection(conn_id="kubernetes_default", conn_type="kubernetes"),
        Connection(
            conn_id="kylin_default",
            conn_type="kylin",
            host="localhost",
            port=7070,
            login="ADMIN",
            password="KYLIN",
        ),
        Connection(conn_id="leveldb_default", conn_type="leveldb", host="localhost"),
        Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998),
        Connection(
            conn_id="local_mysql",
            conn_type="mysql",
            host="localhost",
            login="airflow",
            password="airflow",
            schema="airflow",
        ),
        Connection(
            conn_id="metastore_default",
            conn_type="hive_metastore",
            host="localhost",
            extra='{"authMechanism": "PLAIN"}',
            port=9083,
        ),
        Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017),
        Connection(conn_id="mssql_default", conn_type="mssql", host="localhost", port=1433),
        Connection(
            conn_id="mysql_default",
            conn_type="mysql",
            login="root",
            schema="airflow",
            host="mysql",
        ),
        Connection(conn_id="opsgenie_default", conn_type="http", host="", password=""),
        Connection(
            conn_id="oracle_default",
            conn_type="oracle",
            host="localhost",
            login="root",
            password="password",
            schema="schema",
            port=1521,
        ),
        Connection(
            conn_id="oss_default",
            conn_type="oss",
            extra="""{
                "auth_type": "AK",
                "access_key_id": "<ACCESS_KEY_ID>",
                "access_key_secret": "<ACCESS_KEY_SECRET>",
                "region": "<YOUR_OSS_REGION>"}
                """,
        ),
        Connection(conn_id="pig_cli_default", conn_type="pig_cli", schema="default"),
        Connection(conn_id="pinot_admin_default", conn_type="pinot", host="localhost", port=9000),
        Connection(
            conn_id="pinot_broker_default",
            conn_type="pinot",
            host="localhost",
            port=9000,
            extra='{"endpoint": "/query", "schema": "http"}',
        ),
        Connection(
            conn_id="postgres_default",
            conn_type="postgres",
            login="postgres",
            password="airflow",
            schema="airflow",
            host="postgres",
        ),
        Connection(
            conn_id="presto_default",
            conn_type="presto",
            host="localhost",
            schema="hive",
            port=3400,
        ),
        Connection(conn_id="qdrant_default", conn_type="qdrant", host="qdrant", port=6333),
        Connection(
            conn_id="redis_default",
            conn_type="redis",
            host="redis",
            port=6379,
            extra='{"db": 0}',
        ),
        Connection(
            conn_id="redshift_default",
            conn_type="redshift",
            extra="""{
    "iam": true,
    "cluster_identifier": "<REDSHIFT_CLUSTER_IDENTIFIER>",
    "port": 5439,
    "profile": "default",
    "db_user": "awsuser",
    "database": "dev",
    "region": ""
}""",
        ),
        Connection(
            conn_id="salesforce_default",
            conn_type="salesforce",
            login="username",
            password="password",
            extra='{"security_token": "security_token"}',
        ),
        Connection(
            conn_id="segment_default",
            conn_type="segment",
            extra='{"write_key": "my-segment-write-key"}',
        ),
        Connection(
            conn_id="sftp_default",
            conn_type="sftp",
            host="localhost",
            port=22,
            login="airflow",
            extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}',
        ),
        Connection(
            conn_id="spark_default",
            conn_type="spark",
            host="yarn",
            extra='{"queue": "root.default"}',
        ),
        Connection(
            conn_id="sqlite_default",
            conn_type="sqlite",
            host=os.path.join(gettempdir(), "sqlite_default.db"),
        ),
        Connection(conn_id="ssh_default", conn_type="ssh", host="localhost"),
        Connection(
            conn_id="tableau_default",
            conn_type="tableau",
            host="https://tableau.server.url",
            login="user",
            password="password",
            extra='{"site_id": "my_site"}',
        ),
        Connection(conn_id="tabular_default", conn_type="tabular", host="https://api.tabulardata.io/ws/v1"),
        Connection(
            conn_id="teradata_default",
            conn_type="teradata",
            host="localhost",
            login="user",
            password="password",
            schema="schema",
        ),
        Connection(
            conn_id="trino_default",
            conn_type="trino",
            host="localhost",
            schema="hive",
            port=3400,
        ),
        Connection(conn_id="vertica_default", conn_type="vertica", host="localhost", port=5433),
        Connection(conn_id="wasb_default", conn_type="wasb", extra='{"sas_token": null}'),
        Connection(conn_id="webhdfs_default", conn_type="hdfs", host="localhost", port=50070),
        Connection(conn_id="yandexcloud_default", conn_type="yandexcloud", schema="default"),
    ]

    # Pass the caller's session for every connection. Otherwise ``provide_session``
    # on ``merge_conn`` falls back to opening a separate session for that entry.
    for conn in default_connections:
        merge_conn(conn, session)
def _get_flask_db(sql_database_uri):
    """
    Build a Flask-SQLAlchemy handle for ``sql_database_uri``.

    A throw-away Flask app is configured with the URI (modification tracking off).
    An ``AirflowDatabaseSessionInterface`` using the ``session`` table is attached to it.

    :param sql_database_uri: database URI for ``SQLALCHEMY_DATABASE_URI``.
    :return: the ``SQLAlchemy`` extension instance bound to that app.
    """
    from flask import Flask
    from flask_sqlalchemy import SQLAlchemy

    from airflow.www.session import AirflowDatabaseSessionInterface

    app = Flask(__name__)
    app.config.update(
        SQLALCHEMY_DATABASE_URI=sql_database_uri,
        SQLALCHEMY_TRACK_MODIFICATIONS=False,
    )
    flask_db = SQLAlchemy(app)
    # The interface object is discarded; it is constructed only for its side effects.
    # NOTE(review): presumably it declares the ``session`` table on ``flask_db`` so
    # that ``create_all()`` creates it — confirm in airflow.www.session.
    AirflowDatabaseSessionInterface(app=app, db=flask_db, table="session", key_prefix="")
    return flask_db
def _create_db_from_orm(session):
    """
    Create the schema directly from ORM metadata and stamp alembic at ``head``.

    This runs while holding the ``MIGRATIONS`` global lock. It creates:

    * the Airflow models' tables;
    * the FAB auth-manager models' tables;
    * the Flask web-session table.

    :param session: SQLAlchemy session whose bound engine receives the tables.
    """
    from alembic import command

    from airflow.models.base import Base
    from airflow.providers.fab.auth_manager.models import Model

    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
        engine = session.get_bind().engine
        for orm_metadata in (Base.metadata, Model.metadata):
            orm_metadata.create_all(engine)
        # The Flask session table lives in a separate Flask-SQLAlchemy metadata.
        _get_flask_db(engine.url).create_all()
        # Record the freshly built schema as being at the latest migration.
        command.stamp(_get_alembic_config(), "head")
@provide_session
def initdb(session: Session = NEW_SESSION, load_connections: bool = True, use_migration_files: bool = False):
    """
    Initialize Airflow database.

    If the database already has an alembic revision, or ``use_migration_files`` is
    set, the schema is brought up to date via :func:`upgradedb`. Otherwise it is
    created straight from the ORM models.

    Afterwards:

    * default connections are loaded, if both ``[database] load_default_connections``
      and ``load_connections`` are true;
    * the default pool is ensured;
    * log templates are synchronized.

    :param session: SQLAlchemy ORM session.
    :param load_connections: allow loading the default connections.
    :param use_migration_files: always build the schema by running migrations.
    """
    import_all_models()

    current_revision = _get_current_revision(session)
    if current_revision or use_migration_files:
        upgradedb(session=session, use_migration_files=use_migration_files)
    else:
        _create_db_from_orm(session=session)

    wants_default_connections = conf.getboolean("database", "LOAD_DEFAULT_CONNECTIONS") and load_connections
    if wants_default_connections:
        create_default_connections(session=session)

    # Add default pool & sync log_template
    add_default_pool_if_not_exists(session=session)
    synchronize_log_template(session=session)
def _get_alembic_config():
    """
    Build the alembic ``Config`` for Airflow's migrations.

    The ini file comes from ``[database] alembic_ini_file_path``; a relative path is
    resolved against the installed ``airflow`` package directory. ``script_location``
    points at the package's ``migrations`` directory, and ``sqlalchemy.url`` is the
    configured metadata DB connection.
    """
    from alembic.config import Config

    package_dir = os.path.dirname(airflow.__file__)
    ini_path = conf.get("database", "alembic_ini_file_path")
    if not os.path.isabs(ini_path):
        ini_path = os.path.join(package_dir, ini_path)
    config = Config(ini_path)

    migrations_dir = os.path.join(package_dir, "migrations")
    # Option values go through ConfigParser interpolation, so a literal "%" must be doubled.
    config.set_main_option("script_location", migrations_dir.replace("%", "%%"))
    config.set_main_option("sqlalchemy.url", settings.SQL_ALCHEMY_CONN.replace("%", "%%"))
    return config
def _get_script_object(config=None) -> ScriptDirectory:
    """
    Return the alembic ``ScriptDirectory`` for ``config``.

    :param config: alembic ``Config``; falsy values fall back to :func:`_get_alembic_config`.
    """
    from alembic.script import ScriptDirectory

    return ScriptDirectory.from_config(config or _get_alembic_config())
def _get_current_revision(session):
    """
    Return the alembic revision recorded in the database behind ``session``.

    The result is whatever ``MigrationContext.get_current_revision()`` reports.
    NOTE(review): presumably ``None`` for a database that was never stamped; confirm.
    """
    from alembic.migration import MigrationContext

    return MigrationContext.configure(session.connection()).get_current_revision()
def check_migrations(timeout):
    """
    Wait for all airflow migrations to complete.

    Roughly once per second, the alembic heads defined by the migration scripts are
    compared with the heads recorded in the database. The function returns as soon
    as they match.

    :param timeout: Timeout for the migration in seconds
    :return: None
    :raises TimeoutError: if the heads still differ after ``timeout`` checks.
    """
    # Always check at least once, even for a falsy or negative timeout.
    timeout = max(timeout or 0, 1)
    with _configured_alembic_environment() as env:
        context = env.get_context()
        source_heads = None
        db_heads = None
        for ticker in range(timeout):
            source_heads = set(env.script.get_heads())
            db_heads = set(context.get_current_heads())
            if source_heads == db_heads:
                return
            time.sleep(1)
            log.info("Waiting for migrations... %s second(s)", ticker)
        raise TimeoutError(
            f"There are still unapplied migrations after {timeout} seconds. Migration "
            f"Head(s) in DB: {db_heads} | Migration Head(s) in Source Code: {source_heads}"
        )
@contextlib.contextmanager
def _configured_alembic_environment() -> Generator[EnvironmentContext, None, None]:
    """
    Yield an alembic ``EnvironmentContext`` configured against the Airflow metadata DB.

    A connection from ``settings.engine`` stays open for the duration of the context.
    The ``alembic`` logger is raised to WARNING while the environment is configured.
    Its previous level is then restored, even if configuration fails.
    """
    from alembic.runtime.environment import EnvironmentContext

    config = _get_alembic_config()
    script = _get_script_object(config)

    with EnvironmentContext(
        config,
        script,
    ) as env, settings.engine.connect() as connection:
        alembic_logger = logging.getLogger("alembic")
        level = alembic_logger.level
        alembic_logger.setLevel(logging.WARNING)
        try:
            env.configure(connection)
        finally:
            # Restore the caller's logging level even if configure() raises.
            alembic_logger.setLevel(level)

        yield env
def check_and_run_migrations():
    """
    Check and run migrations if necessary. Only use in a tty.

    The alembic heads in the migration scripts are compared with those in the
    database. An empty database calls for :func:`initdb`; differing heads call for
    :func:`upgradedb`.

    * On a TTY the user is asked to confirm (4 second timeout; no answer skips).
    * Without a TTY, differing heads are reported on stderr and the process exits
      with status 1.
    """
    with _configured_alembic_environment() as env:
        context = env.get_context()
        source_heads = set(env.script.get_heads())
        db_heads = set(context.get_current_heads())
        db_command = None
        command_name = None
        verb = None
    if len(db_heads) < 1:
        db_command = initdb
        command_name = "init"
        verb = "initialize"
    elif source_heads != db_heads:
        db_command = upgradedb
        command_name = "upgrade"
        verb = "upgrade"

    if sys.stdout.isatty() and verb:
        print()
        question = f"Please confirm database {verb} (or wait 4 seconds to skip it). Are you sure? [y/N]"
        try:
            answer = helpers.prompt_with_timeout(question, timeout=4, default=False)
        except AirflowException:
            # Swallowed like before. NOTE(review): presumably raised when the prompt
            # times out, which the question above describes as "skip it".
            answer = False
        if answer:
            try:
                db_command()
                print(f"DB {verb} done")
            except Exception as error:
                from airflow.version import version

                # Keep the error next to the follow-up hint, both on stderr.
                print(error, file=sys.stderr)
                print(
                    "You still have unapplied migrations. "
                    f"You may need to {verb} the database by running `airflow db {command_name}`. "
                    f"Make sure the command is run using Airflow version {version}.",
                    file=sys.stderr,
                )
                sys.exit(1)
    elif source_heads != db_heads:
        from airflow.version import version

        print(
            f"ERROR: You need to {verb} the database. Please run `airflow db {command_name}`. "
            f"Make sure the command is run using Airflow version {version}.",
            file=sys.stderr,
        )
        sys.exit(1)
def _reserialize_dags(*, session: Session) -> None:
    """
    Drop every serialized DAG row, then re-collect DAGs and sync them back to the DB.

    :param session: SQLAlchemy ORM session used for the delete and the sync.
    """
    from airflow.models.dagbag import DagBag
    from airflow.models.serialized_dag import SerializedDagModel

    purge_stmt = delete(SerializedDagModel).execution_options(synchronize_session=False)
    session.execute(purge_stmt)

    # Start from an empty bag, then collect unconditionally (not only updated files).
    bag = DagBag(collect_dags=False)
    bag.collect_dags(only_if_updated=False)
    bag.sync_to_db(session=session)
@provide_session
def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
    """
    Synchronize log template configs with table.

    The function reads ``[logging] log_filename_template`` and
    ``[elasticsearch] log_id_template``. It adds a ``LogTemplate`` row unless the
    newest stored row already matches both values. Extra rows may be added first:

    * the pre-2.3.0 templates, when the table is empty and the config still uses
      the defaults;
    * the pre-upgrade values, when the config was upgraded in place and those values
      are not stored yet.

    Rows are only added to ``session``; this function never commits.
    """
    # NOTE: SELECT queries in this function are INTENTIONALLY written with the
    # SQL builder style, not the ORM query API. This avoids configuring the ORM
    # unless we need to insert something, speeding up CLI in general.
    from airflow.models.tasklog import LogTemplate

    reflected = reflect_tables([LogTemplate], session)
    template_tbl: Table | None = reflected.tables.get(LogTemplate.__tablename__)
    if template_tbl is None:
        log.info("Log template table does not exist (added in 2.3.0); skipping log template sync.")
        return

    filename = conf.get("logging", "log_filename_template")
    elasticsearch_id = conf.get("elasticsearch", "log_id_template")

    newest_row_query = (
        select(template_tbl.c.filename, template_tbl.c.elasticsearch_id)
        .order_by(template_tbl.c.id.desc())
        .limit(1)
    )
    stored = session.execute(newest_row_query).first()

    if not stored:
        # Empty table with default config: seed the pre-2.3.0 templates so that logs
        # written by older versions can still be located.
        is_default_log_id = elasticsearch_id == conf.get_default_value("elasticsearch", "log_id_template")
        is_default_filename = filename == conf.get_default_value("logging", "log_filename_template")
        if is_default_log_id and is_default_filename:
            session.add(
                LogTemplate(
                    filename="{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log",
                    elasticsearch_id="{dag_id}-{task_id}-{execution_date}-{try_number}",
                )
            )

    # Values that were upgraded in place must be recorded before the current ones.
    old_filename = conf.upgraded_values.get(("logging", "log_filename_template"), filename)
    old_elasticsearch_id = conf.upgraded_values.get(("elasticsearch", "log_id_template"), elasticsearch_id)
    if old_filename != filename or old_elasticsearch_id != elasticsearch_id:
        # The old value need not be the newest row, so search for it. The table only
        # ever holds a handful of rows, so this is cheap.
        previous_query = (
            select(template_tbl.c.id)
            .where(
                or_(
                    template_tbl.c.filename == old_filename,
                    template_tbl.c.elasticsearch_id == old_elasticsearch_id,
                )
            )
            .order_by(template_tbl.c.id.desc())
            .limit(1)
        )
        if not session.execute(previous_query).first():
            session.add(LogTemplate(filename=old_filename, elasticsearch_id=old_elasticsearch_id))

    already_current = stored and stored.filename == filename and stored.elasticsearch_id == elasticsearch_id
    if already_current:
        return
    session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))
def check_conn_id_duplicates(session: Session) -> Iterable[str]:
    """
    Check unique conn_id in connection table.

    Yields one error message listing every ``conn_id`` that occurs more than once.
    If the query fails with an Operational/Programming error, the session is rolled
    back and nothing is yielded.

    :param session: session of the sqlalchemy
    """
    from airflow.models.connection import Connection

    duplicated_ids_query = select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
    try:
        dups = session.scalars(duplicated_ids_query).all()
    except (exc.OperationalError, exc.ProgrammingError):
        # The table may not exist yet on a fresh database.
        session.rollback()
        return

    if not dups:
        return
    yield (
        "Seems you have non unique conn_id in connection table.\n"
        "You have to manage those duplicate connections "
        "before upgrading the database.\n"
        f"Duplicated conn_id: {dups}"
    )
def check_username_duplicates(session: Session) -> Iterable[str]:
    """
    Check unique username in User & RegisterUser table.

    For each of the two tables holding duplicated usernames, one error message is
    yielded. A table whose query fails with an Operational/Programming error is
    skipped after rolling the session back.

    :param session: session of the sqlalchemy
    :return: iterable of error messages (possibly empty).
    """
    from airflow.providers.fab.auth_manager.models import RegisterUser, User

    for model in (User, RegisterUser):
        try:
            dups = session.execute(
                select(model.username)  # type: ignore[attr-defined]
                .group_by(model.username)  # type: ignore[attr-defined]
                .having(func.count() > 1)
            ).all()
        except (exc.OperationalError, exc.ProgrammingError):
            # The table may not exist yet on a fresh database.
            session.rollback()
            continue
        if not dups:
            continue
        yield (
            f"Seems you have mixed case usernames in {model.__table__.name} table.\n"  # type: ignore
            "You have to rename or delete those mixed case usernames "
            "before upgrading the database.\n"
            f"usernames with mixed cases: {[dup.username for dup in dups]}"
        )
def reflect_tables(tables: list[MappedClassProtocol | str] | None, session):
    """
    When running checks prior to upgrades, we use reflection to determine current state of the database.

    The function reflects the live schema of the requested tables (foreign keys are
    not resolved) into a fresh SqlAlchemy ``MetaData``.

    :param tables: mapped classes or table names to reflect; ``None`` reflects every table.
    :param session: session whose ``bind`` is inspected.
    :return: the populated ``MetaData``; tables that could not be reflected are absent.
    """
    import sqlalchemy.schema

    bind = session.bind
    metadata = sqlalchemy.schema.MetaData()

    if tables is None:
        metadata.reflect(bind=bind, resolve_fks=False)
        return metadata

    for tbl in tables:
        try:
            name = tbl if isinstance(tbl, str) else tbl.__tablename__
            metadata.reflect(bind=bind, only=[name], extend_existing=True, resolve_fks=False)
        except exc.InvalidRequestError:
            # Presumably raised by ``only=`` when the table is missing; skip it.
            continue
    return metadata
def check_table_for_duplicates(
    *, session: Session, table_name: str, uniqueness: list[str], version: str
) -> Iterable[str]:
    """
    Check table for duplicates, given a list of columns which define the uniqueness of the table.

    Rows that share the same values in ``uniqueness`` are moved into a separate
    ``_airflow_moved`` "duplicates" table named after ``version``.

    Usage example:

    .. code-block:: python

        def check_task_fail_for_duplicates(session):
            from airflow.models.taskfail import TaskFail

            metadata = reflect_tables([TaskFail], session)
            task_fail = metadata.tables.get(TaskFail.__tablename__)  # type: ignore
            if task_fail is None:  # table not there
                return
            if "run_id" in task_fail.columns:  # upgrade already applied
                return
            yield from check_table_for_duplicates(
                table_name=task_fail.name,
                uniqueness=["dag_id", "task_id", "execution_date"],
                session=session,
                version="2.3",
            )

    :param table_name: table name to check
    :param uniqueness: uniqueness constraint to evaluate against
    :param session: session of the sqlalchemy
    :param version: Airflow version used to name the table the duplicates are moved to
    :return: iterable of error messages (possibly empty)
    """
    minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
    try:
        # Keep this as an un-executed subquery construct. It is aggregated below and
        # also passed to ``_move_duplicate_data_to_new_table``, which needs its
        # columns (``.c``). Executing it would produce a Result without ``.c``.
        subquery = (
            select(minimal_table_obj, func.count().label("dupe_count"))
            .group_by(*(text(x) for x in uniqueness))
            .having(func.count() > text("1"))
            .subquery()
        )
        # Total number of rows that belong to a duplicated group.
        dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count)))
        if not dupe_count:
            # there are no duplicates; nothing to do.
            return

        log.warning("Found %s duplicates in table %s. Will attempt to move them.", dupe_count, table_name)

        metadata = reflect_tables(tables=[table_name], session=session)
        if table_name not in metadata.tables:
            yield f"Table {table_name} does not exist in the database."
            # Nothing can be moved; avoid a KeyError on the lookup below.
            return

        # We can't use the model here since it may differ from the db state due to
        # this function is run prior to migration. Use the reflected table instead.
        table_obj = metadata.tables[table_name]

        _move_duplicate_data_to_new_table(
            session=session,
            source_table=table_obj,
            subquery=subquery,
            uniqueness=uniqueness,
            target_table_name=_format_airflow_moved_table_name(table_name, version, "duplicates"),
        )
    except (exc.OperationalError, exc.ProgrammingError):
        # fallback if `table_name` hasn't been created yet
        session.rollback()
def check_conn_type_null(session: Session) -> Iterable[str]:
    """
    Check nullable conn_type column in Connection table.

    Yields one error message listing the ``conn_id`` of every connection whose
    ``conn_type`` is NULL. If the query fails with an Operational, Programming or
    Internal error, the session is rolled back and nothing is yielded.

    :param session: session of the sqlalchemy
    """
    from airflow.models.connection import Connection

    null_type_query = select(Connection.conn_id).where(Connection.conn_type.is_(None))
    try:
        n_nulls = session.scalars(null_type_query).all()
    except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
        # The table may not exist yet on a fresh database.
        session.rollback()
        return

    if not n_nulls:
        return
    yield (
        "The conn_type column in the connection "
        "table must contain content.\n"
        "Make sure you don't have null "
        "in the conn_type column.\n"
        f"Null conn_type conn_id: {n_nulls}"
    )
def _format_dangling_error(source_table, target_table, invalid_count, reason):
    """
    Build the message shown when invalid rows cannot be moved aside.

    :param source_table: table containing the invalid rows.
    :param target_table: "moved" table that already exists and blocks the move.
    :param invalid_count: number of invalid rows (selects "row" vs "rows").
    :param reason: phrase describing why the rows are invalid.
    """
    noun = "rows"
    if invalid_count == 1:
        noun = "row"
    problem = f"The {source_table} table has {invalid_count} {noun} {reason}, which is invalid."
    explanation = (
        f"We could not move them out of the way because the {target_table} table "
        "already exists in your database."
    )
    remedy = (
        f"Please either drop the {target_table} table, or manually delete the "
        f"invalid rows from the {source_table} table."
    )
    return " ".join((problem, explanation, remedy))
def check_run_id_null(session: Session) -> Iterable[str]:
    """
    Move dag_run rows with a NULL ``dag_id``, ``run_id`` or ``execution_date`` out of the way.

    Matching rows are copied into an ``_airflow_moved`` "dangling" table (version
    "2.2") and then deleted from ``dag_run``. If that target table already exists,
    nothing is moved and an error message is yielded instead.

    :param session: session of the sqlalchemy
    :return: iterable of error messages (possibly empty)
    """
    from airflow.models.dagrun import DagRun

    metadata = reflect_tables([DagRun], session)

    # We can't use the model here since it may differ from the db state due to
    # this function is run prior to migration. Use the reflected table instead.
    dagrun_table = metadata.tables.get(DagRun.__tablename__)
    if dagrun_table is None:
        return

    invalid_dagrun_filter = or_(
        dagrun_table.c.dag_id.is_(None),
        dagrun_table.c.run_id.is_(None),
        dagrun_table.c.execution_date.is_(None),
    )
    invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter))
    if not invalid_dagrun_count:
        return

    dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling")
    bind = session.get_bind()
    if dagrun_dangling_table_name in inspect(bind).get_table_names():
        yield _format_dangling_error(
            source_table=dagrun_table.name,
            target_table=dagrun_dangling_table_name,
            invalid_count=invalid_dagrun_count,
            reason="with a NULL dag_id, run_id, or execution_date",
        )
        return

    _create_table_as(
        dialect_name=bind.dialect.name,
        # FromClause.select() no longer accepts a positional WHERE clause in
        # SQLAlchemy 2.0, so apply the filter with .where().
        source_query=dagrun_table.select().where(invalid_dagrun_filter),
        target_table_name=dagrun_dangling_table_name,
        source_table_name=dagrun_table.name,
        session=session,
    )
    # Named so it does not shadow the module-level ``sqlalchemy.delete`` import.
    delete_invalid_stmt = dagrun_table.delete().where(invalid_dagrun_filter)
    session.execute(delete_invalid_stmt)
def _create_table_as(
    *,
    session,
    dialect_name: str,
    source_query: Query,
    target_table_name: str,
    source_table_name: str,
):
    """
    Create ``target_table_name`` and fill it with the rows returned by ``source_query``.

    "CREATE TABLE ... AS SELECT" (CTAS) has to be issued differently per dialect:

    * MySQL: ``CREATE TABLE ... LIKE <source table>`` followed by ``INSERT INTO ... SELECT``.
      A single CTAS statement violates GTID consistency on replicated servers (ERROR 1786),
      so the two-step form is used for every MySQL server.
    * any other dialect: one ``CREATE TABLE ... AS SELECT`` statement.

    :param session: sqlalchemy session used to execute the statements
    :param dialect_name: name of the SQLAlchemy dialect of the session's bind
    :param source_query: query whose rows populate the new table
    :param target_table_name: name of the table to create
    :param source_table_name: table whose structure MySQL copies via ``CREATE TABLE ... LIKE``
    """

    def _compiled_select():
        return source_query.selectable.compile(bind=session.get_bind())

    if dialect_name != "mysql":
        # Postgres and SQLite both support the same "CREATE TABLE a AS SELECT ..." syntax
        session.execute(text(f"CREATE TABLE {target_table_name} AS {_compiled_select()}"))
        return

    # MySQL with replication needs this split in to two queries, so just do it for all MySQL
    # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT.
    session.execute(text(f"CREATE TABLE {target_table_name} LIKE {source_table_name}"))
    session.execute(text(f"INSERT INTO {target_table_name} {_compiled_select()}"))
def _move_dangling_data_to_new_table(
    session, source_table: Table, source_query: Query, target_table_name: str
):
    """
    Copy the rows selected by ``source_query`` into a new table, then delete them from the source.

    If the query selected no rows, the newly created table is dropped again. Each step is
    committed on ``session`` as it completes.

    :param session: sqlalchemy session for the metadata db
    :param source_table: reflected table the bad rows live in
    :param source_query: SELECT returning the rows to move (all columns of ``source_table``)
    :param target_table_name: name of the table to park the moved rows in
    """
    bind = session.get_bind()
    dialect_name = bind.dialect.name

    # First: Create moved rows from new table
    log.debug("running CTAS for table %s", target_table_name)
    _create_table_as(
        dialect_name=dialect_name,
        source_query=source_query,
        target_table_name=target_table_name,
        source_table_name=source_table.name,
        session=session,
    )
    session.commit()

    target_table = source_table.to_metadata(source_table.metadata, name=target_table_name)
    log.debug("checking whether rows were moved for table %s", target_table_name)
    moved_rows_exist_query = select(1).select_from(target_table).limit(1)
    first_moved_row = session.execute(moved_rows_exist_query).all()
    session.commit()

    if not first_moved_row:
        log.debug("no rows moved; dropping %s", target_table_name)
        # no bad rows were found; drop moved rows table.
        target_table.drop(bind=session.get_bind(), checkfirst=True)
    else:
        log.debug("rows moved; purging from %s", source_table.name)
        if dialect_name == "sqlite":
            # SQLite's DELETE cannot join another table, so match primary keys with
            # ``(pk, ...) IN (SELECT pk, ... FROM target)``. ``Session`` has no ``select`` method;
            # use the core ``select()`` and pass it straight to ``in_()`` (passing a ``.subquery()``
            # there is deprecated coercion).
            pk_cols = source_table.primary_key.columns
            delete_stmt = source_table.delete().where(
                tuple_(*pk_cols).in_(select(*target_table.primary_key.columns))
            )
        else:
            # Other dialects: DELETE with a join condition on every primary key column.
            delete_stmt = source_table.delete().where(
                and_(*(col == target_table.c[col.name] for col in source_table.primary_key.columns))
            )
        log.debug(delete_stmt.compile())
        session.execute(delete_stmt)
        session.commit()

    log.debug("exiting move function")
def _dangling_against_dag_run(session, source_table, dag_run):
    """
    Build a SELECT of the ``source_table`` rows that have no matching ``dag_run`` row.

    Rows are matched on ``(dag_id, execution_date)`` with a LEFT OUTER JOIN; unmatched rows
    come back with NULL ``dag_run`` columns, which is what the WHERE clause keeps.

    :param session: unused here; accepted so every "bad rows" function shares one call signature
    :param source_table: reflected table to check
    :param dag_run: reflected ``dag_run`` table
    :return: SELECT returning every column of ``source_table`` for the dangling rows
    """
    matches_dag_run = and_(
        source_table.c.dag_id == dag_run.c.dag_id,
        source_table.c.execution_date == dag_run.c.execution_date,
    )
    labelled_columns = [col.label(col.name) for col in source_table.c]

    stmt = select(*labelled_columns)
    stmt = stmt.join(dag_run, matches_dag_run, isouter=True)
    return stmt.where(dag_run.c.dag_id.is_(None))
def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
    """
    Build a SELECT of the ``source_table`` rows that have no matching dag run or task instance.

    This is used to identify rows that need to be removed from tables prior to adding a TI fk.

    Since this check is applied prior to running the migrations, the join to ``task_instance``
    depends on which revision the database is at: before 2.2.0 ``task_instance`` is matched to
    ``dag_run`` on ``execution_date``, from 2.2.0 on ``run_id``.

    :param session: unused here; accepted so every "bad rows" function shares one call signature
    :param source_table: reflected table to check
    :param dag_run: reflected ``dag_run`` table
    :param task_instance: reflected ``task_instance`` table
    :return: SELECT returning every column of ``source_table`` for the dangling rows
    """
    # The source table is always tied to its dag run by (dag_id, execution_date).
    dr_join_cond = and_(
        source_table.c.dag_id == dag_run.c.dag_id,
        source_table.c.execution_date == dag_run.c.execution_date,
    )

    if "run_id" in task_instance.c:
        # db is 2.2.0 <= version < 2.3.0
        ti_run_match = dag_run.c.run_id == task_instance.c.run_id
    else:
        # db is < 2.2.0
        ti_run_match = dag_run.c.execution_date == task_instance.c.execution_date

    ti_join_cond = and_(
        dag_run.c.dag_id == task_instance.c.dag_id,
        ti_run_match,
        source_table.c.task_id == task_instance.c.task_id,
    )

    missing_either = or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))
    stmt = select(*(c.label(c.name) for c in source_table.c))
    stmt = stmt.outerjoin(dag_run, dr_join_cond).outerjoin(task_instance, ti_join_cond)
    return stmt.where(missing_either)
def _move_duplicate_data_to_new_table(
    session, source_table: Table, subquery: Query, uniqueness: list[str], target_table_name: str
):
    """
    When adding a uniqueness constraint we first should ensure that there are no duplicate rows.

    This function accepts a subquery that should return one record for each row with duplicates (e.g.
    a group by with having count(*) > 1). We select from ``source_table`` getting all rows matching the
    subquery result and store in ``target_table_name``. Then to purge the duplicates from the source table,
    we do a DELETE FROM with a join to the target table (which now contains the dupes). On SQLite, which
    cannot DELETE with a join, the rows are deleted by ``ROWID`` instead.

    :param session: sqlalchemy session for metadata db
    :param source_table: table to purge dupes from
    :param subquery: the subquery that returns the duplicate rows
    :param uniqueness: the string list of columns used to define the uniqueness for the table. used in
        building the DELETE FROM join condition.
    :param target_table_name: name of the table in which to park the duplicate rows
    """
    bind = session.get_bind()
    dialect_name = bind.dialect.name

    query = (
        select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns))
        .select_from(source_table)
        .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness)))
    )

    _create_table_as(
        session=session,
        dialect_name=dialect_name,
        source_query=query,
        target_table_name=target_table_name,
        source_table_name=source_table.name,
    )

    # we must ensure that the CTAS table is created prior to the DELETE step since we have to join to it
    session.commit()

    if dialect_name == "sqlite":
        # Delete every source row picked by ``query`` via its ROWID. ``with_only_columns`` takes
        # columns positionally; passing a list is deprecated in SQLAlchemy 1.4 and removed in 2.0.
        subq = query.selectable.with_only_columns(text(f"{source_table}.ROWID"))
        delete_stmt = source_table.delete().where(column("ROWID").in_(subq))
    else:
        # Join against the reflected copy of the duplicates table on the uniqueness columns.
        metadata = reflect_tables([target_table_name], session)
        target_table = metadata.tables[target_table_name]
        where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness))
        # ``Table.delete(whereclause)`` is deprecated in 1.4 / removed in 2.0; use ``.where()``.
        delete_stmt = source_table.delete().where(where_clause)

    session.execute(delete_stmt)
def check_bad_references(session: Session) -> Iterable[str]:
    """
    Go through each table and look for records that can't be mapped to a dag run.

    When we find such "dangling" rows we back them up in a special table and delete them
    from the main table.

    Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id` in many tables.
    Tables that already have a ``run_id`` column are skipped. If a table's backup ("dangling") table
    already exists, its rows are not moved; an error message is yielded for it when it still has bad
    rows, and the session is rolled back at the end instead of committed.

    :param session: session of the sqlalchemy
    :return: yields one error message per table whose bad rows could not be moved
    """
    from airflow.models.dagrun import DagRun
    from airflow.models.renderedtifields import RenderedTaskInstanceFields
    from airflow.models.taskfail import TaskFail
    from airflow.models.taskinstance import TaskInstance
    from airflow.models.taskreschedule import TaskReschedule
    from airflow.models.xcom import XCom

    @dataclass
    class BadReferenceConfig:
        """
        Bad reference config class.

        :param bad_rows_func: function that returns the SELECT of bad rows for a source table
        :param join_tables: names of reflected tables passed to ``bad_rows_func`` as keyword arguments
        :param ref_table: information-only identifier for categorizing the missing ref
        """

        bad_rows_func: Callable
        join_tables: list[str]
        ref_table: str

    missing_dag_run_config = BadReferenceConfig(
        bad_rows_func=_dangling_against_dag_run,
        join_tables=["dag_run"],
        ref_table="dag_run",
    )
    missing_ti_config = BadReferenceConfig(
        bad_rows_func=_dangling_against_task_instance,
        join_tables=["dag_run", "task_instance"],
        ref_table="task_instance",
    )

    models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [
        (TaskInstance, "2.2", missing_dag_run_config),
        (TaskReschedule, "2.2", missing_ti_config),
        (RenderedTaskInstanceFields, "2.3", missing_ti_config),
        (TaskFail, "2.3", missing_ti_config),
        (XCom, "2.3", missing_ti_config),
    ]
    metadata = reflect_tables([*(entry[0] for entry in models_list), DagRun, TaskInstance], session)

    key_tables_missing = (
        not metadata.tables
        or metadata.tables.get(DagRun.__tablename__) is None
        or metadata.tables.get(TaskInstance.__tablename__) is None
    )
    if key_tables_missing:
        # Key table doesn't exist -- likely empty DB.
        return

    existing_table_names = set(inspect(session.get_bind()).get_table_names())
    errored = False

    for model, change_version, bad_ref_cfg in models_list:
        log.debug("checking model %s", model.__tablename__)
        # We can't use the model here since it may differ from the db state due to
        # this function is run prior to migration. Use the reflected table instead.
        source_table = metadata.tables.get(model.__tablename__)  # type: ignore
        # Skip tables that don't exist, and tables where the migration was already applied.
        if source_table is None or "run_id" in source_table.columns:
            continue

        joined_tables = {name: metadata.tables[name] for name in bad_ref_cfg.join_tables}
        bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **joined_tables)

        dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, "dangling")
        if dangling_table_name not in existing_table_names:
            log.debug("moving data for table %s", source_table.name)
            _move_dangling_data_to_new_table(
                session,
                source_table,
                bad_rows_query,
                dangling_table_name,
            )
            continue

        # The backup table already exists: we can't move the rows, only report them.
        invalid_row_count = get_query_count(bad_rows_query, session=session)
        if invalid_row_count:
            yield _format_dangling_error(
                source_table=source_table.name,
                target_table=dangling_table_name,
                invalid_count=invalid_row_count,
                reason=f"without a corresponding {bad_ref_cfg.ref_table} row",
            )
            errored = True

    if errored:
        session.rollback()
    else:
        session.commit()
@provide_session
def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
    """
    Run every pre-migration check in turn and yield the error messages they report.

    :param session: session of the sqlalchemy
    :return: yields error messages from all checks, in check order
    """
    checks: tuple[Callable[..., Iterable[str]], ...] = (
        check_conn_id_duplicates,
        check_conn_type_null,
        check_run_id_null,
        check_bad_references,
        check_username_duplicates,
    )
    for check_function in checks:
        log.debug("running check function %s", check_function.__name__)
        yield from check_function(session=session)
1530def _offline_migration(migration_func: Callable, config, revision):
1531 with warnings.catch_warnings():
1532 warnings.simplefilter("ignore")
1533 logging.disable(logging.CRITICAL)
1534 migration_func(config, revision, sql=True)
1535 logging.disable(logging.NOTSET)
def print_happy_cat(message):
    """
    Print ``message`` followed by a small ASCII-art cat.

    When stdout is a TTY each line is centred on the terminal width; otherwise the lines are
    printed as-is (centred on a width of 0).
    """
    width = os.get_terminal_size().columns if sys.stdout.isatty() else 0
    cat = (
        """/\\_/\\""",
        """(='_' )""",
        """(,(") (")""",
        """^^^""",
    )
    for line in (message, *cat):
        print(line.center(width))
def _revision_greater(config, this_rev, base_rev):
    """
    Return True if alembic can iterate the revision history from ``base_rev`` up to ``this_rev``.

    Any exception raised while walking the history is treated as "not greater" (False).
    This ensures that the revisions are above the given base revision.
    """
    script = _get_script_object(config)
    try:
        for _ in script.revision_map.iterate_revisions(upper=this_rev, lower=base_rev):
            pass
    except Exception:
        return False
    return True
def _revisions_above_min_for_offline(config, revisions) -> None:
    """
    Check that all supplied revision ids are above the minimum revision for the dialect.

    :param config: Alembic config
    :param revisions: list of Alembic revision ids
    :raises SystemExit: if the configured database is SQLite (offline migration unsupported)
    :raises ValueError: if a revision is not reachable from the dialect's minimum revision
    :return: None
    """
    dbname = settings.engine.dialect.name
    if dbname == "sqlite":
        raise SystemExit("Offline migration not supported for SQLite.")

    if dbname == "mssql":
        min_version, min_revision = "2.2.0", "7b2661a43ba3"
    else:
        min_version, min_revision = "2.0.0", "e959f08ac86c"

    # Each revision must have history back to `min_revision`.
    for rev in revisions:
        if _revision_greater(config, rev, min_revision):
            continue
        raise ValueError(
            f"Error while checking history for revision range {min_revision}:{rev}. "
            f"Check that {rev} is a valid revision. "
            f"For dialect {dbname!r}, supported revision for offline migration is from {min_revision} "
            f"which corresponds to Airflow {min_version}."
        )
@provide_session
def upgradedb(
    *,
    to_revision: str | None = None,
    from_revision: str | None = None,
    show_sql_only: bool = False,
    reserialize_dags: bool = True,
    session: Session = NEW_SESSION,
    use_migration_files: bool = False,
):
    """
    Upgrades the DB.

    :param to_revision: Optional Alembic revision ID to upgrade *to*.
        If omitted, upgrades to latest revision.
    :param from_revision: Optional Alembic revision ID to upgrade *from*.
        Only supported together with ``show_sql_only=True``.
    :param show_sql_only: if True, migration statements will be printed but not executed.
    :param reserialize_dags: if True, re-serialize DAGs after the upgrade has run.
    :param session: sqlalchemy session with connection to Airflow metadata database
    :param use_migration_files: if True, an empty database is brought up by running the migrations
        instead of being initialized directly via ``initdb``.
    :raises SystemExit: with code 1 if a pre-migration check reports errors
    :return: None
    """
    if from_revision and not show_sql_only:
        raise AirflowException("`from_revision` only supported with `sql_only=True`.")

    if not settings.SQL_ALCHEMY_CONN:
        raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set. This is a critical assertion.")
    # alembic adds significant import time, so we import it lazily
    from alembic import command

    import_all_models()

    config = _get_alembic_config()

    if show_sql_only:
        if not from_revision:
            from_revision = _get_current_revision(session)

        if not to_revision:
            # Reuse the already-loaded alembic config instead of building a new one.
            script = _get_script_object(config)
            to_revision = script.get_current_head()

        if to_revision == from_revision:
            print_happy_cat("No migrations to apply; nothing to do.")
            return

        if not _revision_greater(config, to_revision, from_revision):
            raise ValueError(
                f"Requested *to* revision {to_revision} is older than *from* revision {from_revision}. "
                "Please check your requested versions / revisions."
            )
        _revisions_above_min_for_offline(config=config, revisions=[from_revision, to_revision])

        _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}")
        return  # only running sql; our job is done

    errors_seen = False
    for err in _check_migration_errors(session=session):
        if not errors_seen:
            log.error("Automatic migration is not available")
            errors_seen = True
        log.error("%s", err)

    if errors_seen:
        # ``sys.exit`` rather than the ``site``-provided ``exit`` builtin, which is not guaranteed
        # to exist (e.g. under ``python -S`` or frozen interpreters).
        sys.exit(1)

    if not to_revision and not _get_current_revision(session=session) and not use_migration_files:
        # Don't load default connections
        # New DB; initialize and exit
        initdb(session=session, load_connections=False)
        return
    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
        import sqlalchemy.pool

        log.info("Creating tables")
        val = os.environ.get("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE")
        try:
            # Reconfigure the ORM to use _EXACTLY_ one connection, otherwise some db engines hang forever
            # trying to ALTER TABLEs
            os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = "1"
            settings.reconfigure_orm(pool_class=sqlalchemy.pool.SingletonThreadPool)
            command.upgrade(config, revision=to_revision or "heads")
        finally:
            if val is None:
                os.environ.pop("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE", None)
            else:
                os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = val
            # Restore the normal ORM configuration even if the upgrade failed, so the process is
            # not left running on the single-connection pool.
            settings.reconfigure_orm()

        if reserialize_dags:
            _reserialize_dags(session=session)
        add_default_pool_if_not_exists(session=session)
        synchronize_log_template(session=session)
@provide_session
def resetdb(session: Session = NEW_SESSION, skip_init: bool = False, use_migration_files: bool = False):
    """
    Clear out the database.

    While holding the migrations lock, drops the Airflow model tables (``drop_airflow_models``) and
    any "moved" tables (``drop_airflow_moved_tables``) in one transaction on a dedicated connection,
    then re-initializes the database unless ``skip_init`` is set.

    :param session: sqlalchemy session with connection to Airflow metadata database
    :param skip_init: if True, leave the database empty after dropping the tables
    :param use_migration_files: forwarded to ``initdb``
    :raises RuntimeError: if ``settings.engine`` is not configured
    """
    if not settings.engine:
        raise RuntimeError("The settings.engine must be set. This is a critical assertion")
    log.info("Dropping tables that exist")

    import_all_models()

    # Use the connection as a context manager so it is returned to the pool once the tables
    # are dropped, instead of being leaked until garbage collection.
    with settings.engine.connect() as connection:
        with create_global_lock(session=session, lock=DBLocks.MIGRATIONS), connection.begin():
            drop_airflow_models(connection)
            drop_airflow_moved_tables(connection)

    if not skip_init:
        initdb(session=session, use_migration_files=use_migration_files)
@provide_session
def bootstrap_dagbag(session: Session = NEW_SESSION):
    """
    Build a default ``DagBag``, save its DAGs to the ORM and deactivate DAGs it does not contain.

    :param session: sqlalchemy session with connection to Airflow metadata database
    """
    from airflow.models.dag import DAG
    from airflow.models.dagbag import DagBag

    bag = DagBag()
    bag.sync_to_db(session=session)

    known_dag_ids = bag.dags.keys()
    DAG.deactivate_unknown_dags(known_dag_ids, session=session)
@provide_session
def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session: Session = NEW_SESSION):
    """
    Downgrade the airflow metastore schema to a prior version.

    :param to_revision: The alembic revision to downgrade *to*.
    :param show_sql_only: if True, print sql statements but do not run them
    :param from_revision: if supplied, alembic revision to downgrade *from*. This may only
        be used in conjunction with ``sql=True`` because if we actually run the commands,
        we should only downgrade from the *current* revision.
    :param session: sqlalchemy session for connection to airflow metadata database
    :raises ValueError: if ``from_revision`` is given without ``show_sql_only``
    :raises RuntimeError: if ``settings.SQL_ALCHEMY_CONN`` is not set
    """
    if from_revision and not show_sql_only:
        raise ValueError(
            "`from_revision` can't be combined with `sql=False`. When actually "
            "applying a downgrade (instead of just generating sql), we always "
            "downgrade from current revision."
        )

    if not settings.SQL_ALCHEMY_CONN:
        raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.")

    # alembic adds significant import time, so we import it lazily
    from alembic import command

    log.info("Attempting downgrade to revision %s", to_revision)
    config = _get_alembic_config()

    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
        if not show_sql_only:
            log.info("Applying downgrade migrations.")
            command.downgrade(config, revision=to_revision, sql=show_sql_only)
            return

        log.warning("Generating sql scripts for manual migration.")
        start_revision = from_revision or _get_current_revision(session)
        _offline_migration(command.downgrade, config=config, revision=f"{start_revision}:{to_revision}")
def drop_airflow_models(connection):
    """
    Drop all airflow models.

    Drops the tables of ``airflow.models.base.Base``, of the FAB auth manager models and of the
    Flask db returned by ``_get_flask_db``, then alembic's version table if it exists.

    :param connection: SQLAlchemy Connection
    :return: None
    """
    from airflow.models.base import Base
    from airflow.providers.fab.auth_manager.models import Model

    for model_metadata in (Base.metadata, Model.metadata):
        model_metadata.drop_all(connection)
    _get_flask_db(connection.engine.url).drop_all()

    # alembic adds significant import time, so we import it lazily
    from alembic.migration import MigrationContext

    version_table = MigrationContext.configure(connection)._version
    if inspect(connection).has_table(version_table.name):
        version_table.drop(connection)
def drop_airflow_moved_tables(connection):
    """
    Drop every table whose name starts with ``AIRFLOW_MOVED_TABLE_PREFIX``.

    These are presumably the tables rows were moved into before a migration (e.g. the
    "dangling"/"duplicates" tables). The drops are issued on ``connection`` so they take part in
    the caller's transaction, like ``drop_airflow_models``, instead of running on a separate
    pooled connection from ``settings.engine``.

    :param connection: SQLAlchemy Connection
    :return: None
    """
    from airflow.models.base import Base
    from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX

    tables = set(inspect(connection).get_table_names())
    to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)]
    for tbl in to_delete:
        tbl.drop(connection, checkfirst=False)
        Base.metadata.remove(tbl)
@provide_session
def check(session: Session = NEW_SESSION):
    """
    Check if the database works by executing a trivial query.

    :param session: session of the sqlalchemy
    """
    liveness_probe = text("select 1 as is_alive;")
    session.execute(liveness_probe)
    log.info("Connection successful.")
@enum.unique
class DBLocks(enum.IntEnum):
    """
    Cross-db Identifiers for advisory global database locks.

    Postgres uses int64 lock ids, so the integer value is used there; MySQL uses lock names,
    so ``str()`` is used, which renders ``airflow_<member name>``.
    """

    # Explicit values: these ids are shared by every process taking the lock, so keep them stable.
    MIGRATIONS = 1
    SCHEDULER_CRITICAL_SECTION = 2

    def __str__(self):
        return "airflow_" + self.name
@contextlib.contextmanager
def create_global_lock(
    session: Session,
    lock: DBLocks,
    lock_timeout: int = 1800,
) -> Generator[None, None, None]:
    """
    Contextmanager that will create and teardown a global db lock.

    A dedicated connection is opened from the session's bind for the lifetime of the lock:

    * Postgres: ``pg_advisory_lock`` keyed by ``lock.value``, after ``SET LOCK_TIMEOUT``.
    * MySQL >= 5.6: ``GET_LOCK`` keyed by ``str(lock)``.
    * any other dialect/version: no lock is taken.

    The lock is released on exit, and the dedicated connection is closed afterwards.

    :param session: session whose bind is used to open the locking connection
    :param lock: which lock to take
    :param lock_timeout: value for ``SET LOCK_TIMEOUT`` on Postgres / timeout for ``GET_LOCK`` on MySQL
    :raises RuntimeError: if Postgres reports that the advisory lock could not be released
    """
    conn = session.get_bind().connect()
    dialect = conn.dialect
    try:
        try:
            if dialect.name == "postgresql":
                conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
                conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
            elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
                conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})

            yield
        finally:
            if dialect.name == "postgresql":
                conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
                (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
                if not unlocked:
                    raise RuntimeError("Error releasing DB lock!")
            elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
                conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})
    finally:
        # Return the dedicated connection to the pool even if releasing the lock failed;
        # previously it was never closed and leaked until garbage collection.
        conn.close()
def compare_type(context, inspected_column, metadata_column, inspected_type, metadata_type):
    """
    Compare types between ORM and DB .

    return False if the metadata_type is the same as the inspected_type
    or None to allow the default implementation to compare these
    types. a return value of True means the two types do not
    match and should result in a type change operation.

    Only MySQL VARCHAR-vs-String pairs are decided here (by length alone); everything else
    returns None.
    """
    if context.dialect.name != "mysql":
        return None

    from sqlalchemy import String
    from sqlalchemy.dialects import mysql

    if not (isinstance(inspected_type, mysql.VARCHAR) and isinstance(metadata_type, String)):
        return None

    # This is a hack to get around MySQL VARCHAR collation
    # not being possible to change from utf8_bin to utf8mb3_bin.
    # We only make sure lengths are the same
    return inspected_type.length != metadata_type.length
def compare_server_default(
    context, inspected_column, metadata_column, inspected_default, metadata_default, rendered_metadata_default
):
    """
    Compare server defaults between ORM and DB .

    return True if the defaults are different, False if not, or None to allow the default implementation
    to compare these defaults

    In SQLite: task_instance.map_index & task_reschedule.map_index
    are not comparing accurately. Sometimes they are equal, sometimes they are not.
    Alembic warned that this feature has varied accuracy depending on backends.
    See: (https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.
        environment.EnvironmentContext.configure.params.compare_server_default)
    """
    dialect_name = context.connection.dialect.name
    if dialect_name == "sqlite":
        return False

    # We removed server_default value in ORM to avoid expensive migration
    # (it was removed in postgres DB in migration head 7b2661a43ba3 ).
    # As a side note, server default value here was only actually needed for the migration
    # where we added the column in the first place -- now that it exists and all
    # existing rows are populated with a value this server default is never used.
    is_mysql_ti_pool_slots = (
        dialect_name == "mysql"
        and metadata_column.name == "pool_slots"
        and metadata_column.table.name == "task_instance"
    )
    return False if is_mysql_ti_pool_slots else None
def get_sqla_model_classes():
    """
    Get all SQLAlchemy class mappers.

    SQLAlchemy < 1.4 does not support registry.mappers, so when that attribute is missing the
    values of the declarative class registry are returned instead.
    """
    from airflow.models.base import Base

    try:
        model_classes = [mapped.class_ for mapped in Base.registry.mappers]
    except AttributeError:
        return Base._decl_class_registry.values()
    return model_classes
def get_query_count(query_stmt: Select, *, session: Session) -> int:
    """Get count of a query.

    A SELECT COUNT() FROM is issued against the subquery built from the
    given statement. The ORDER BY clause is stripped from the statement
    since it's unnecessary for COUNT, and can impact query planning and
    degrade performance.

    :meta private:
    """
    unordered_rows = query_stmt.order_by(None).subquery()
    return session.scalar(select(func.count()).select_from(unordered_rows))
def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
    """Check whether there is at least one row matching a query.

    A SELECT 1 FROM is issued against the subquery built from the given
    statement. The ORDER BY clause is stripped from the statement since it's
    unnecessary, and can impact query planning and degrade performance.
    At most one row is requested.

    :meta private:
    """
    exists_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery()).limit(1)
    # ``session.scalar()`` returns None when the query yields no rows; coerce so callers
    # (e.g. ``__bool__`` implementations) always get a real bool as annotated.
    return bool(session.scalar(exists_stmt))
def exists_query(*where: ClauseElement, session: Session) -> bool:
    """Check whether there is at least one row matching given clauses.

    This does a SELECT 1 WHERE ... LIMIT 1 and check the result.

    :meta private:
    """
    first_hit = session.scalar(select(literal(True)).where(*where).limit(1))
    return first_hit is not None
@attrs.define(slots=True)
class LazySelectSequence(Sequence[T]):
    """List-like interface to lazily access a database model query.

    The intended use case is inside a task execution context, where we manage an
    active SQLAlchemy session in the background.

    This is an abstract base class. Each use case should subclass, and implement
    the following static methods:

    * ``_rebuild_select`` is called when a lazy sequence is unpickled. Since it
      is not easy to pickle SQLAlchemy constructs, this class serializes the
      SELECT statements into plain text to storage. This method is called on
      deserialization to convert the textual clause back into an ORM SELECT.
    * ``_process_row`` is called when an item is accessed. The lazy sequence
      uses ``session.execute()`` to fetch rows from the database, and this
      method should know how to process each row into a value.

    :meta private:
    """

    _select_asc: ClauseElement
    _select_desc: ClauseElement
    _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
    _len: int | None = attrs.field(init=False, default=None)

    @classmethod
    def from_select(
        cls,
        select: Select,
        *,
        order_by: Sequence[ClauseElement],
        session: Session | None = None,
    ) -> Self:
        s1 = select
        for col in order_by:
            s1 = s1.order_by(col.asc())
        s2 = select
        for col in order_by:
            s2 = s2.order_by(col.desc())
        return cls(s1, s2, session=session or get_current_task_instance_session())

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        """Rebuild a textual statement into an ORM-configured SELECT statement.

        This should do something like ``select(field).from_statement(stmt)`` to
        reconfigure ORM information to the textual SQL statement.
        """
        raise NotImplementedError

    @staticmethod
    def _process_row(row: Row) -> T:
        """Process a SELECT-ed row into the end value."""
        raise NotImplementedError

    def __repr__(self) -> str:
        counter = "item" if (length := len(self)) == 1 else "items"
        return f"LazySelectSequence([{length} {counter}])"

    def __str__(self) -> str:
        counter = "item" if (length := len(self)) == 1 else "items"
        return f"LazySelectSequence([{length} {counter}])"

    def __getstate__(self) -> Any:
        # We don't want to go to the trouble of serializing SQLAlchemy objects.
        # Converting the statement into a SQL string is the best we can get.
        # The literal_binds compile argument inlines all the values into the SQL
        # string to simplify cross-process communication as much as possible.
        # Theoretically we can do the same for count(), but I think it should be
        # performant enough to calculate only that eagerly.
        s1 = str(self._select_asc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True}))
        s2 = str(self._select_desc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True}))
        return (s1, s2, len(self))

    def __setstate__(self, state: Any) -> None:
        s1, s2, self._len = state
        self._select_asc = self._rebuild_select(text(s1))
        self._select_desc = self._rebuild_select(text(s2))
        self._session = get_current_task_instance_session()

    def __bool__(self) -> bool:
        # ``check_query_exists`` is built on ``session.scalar()``, which can yield None for an
        # empty result; ``__bool__`` must return a real bool or Python raises TypeError.
        return bool(check_query_exists(self._select_asc, session=self._session))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, collections.abc.Sequence):
            return NotImplemented
        z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
        return all(x == y for x, y in z)

    def __reversed__(self) -> Iterator[T]:
        return iter(self._process_row(r) for r in self._session.execute(self._select_desc))

    def __iter__(self) -> Iterator[T]:
        return iter(self._process_row(r) for r in self._session.execute(self._select_asc))

    def __len__(self) -> int:
        if self._len is None:
            self._len = get_query_count(self._select_asc, session=self._session)
        return self._len

    @overload
    def __getitem__(self, key: int) -> T: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[T]: ...

    def __getitem__(self, key: int | slice) -> T | Sequence[T]:
        if isinstance(key, int):
            if key >= 0:
                stmt = self._select_asc.offset(key)
            else:
                stmt = self._select_desc.offset(-1 - key)
            if (row := self._session.execute(stmt.limit(1)).one_or_none()) is None:
                raise IndexError(key)
            return self._process_row(row)
        elif isinstance(key, slice):
            # This implements the slicing syntax. We want to optimize negative
            # slicing (e.g. seq[-10:]) by not doing an additional COUNT query
            # if possible. We can do this unless the start and stop have
            # different signs (i.e. one is positive and another negative).
            start, stop, reverse = _coerce_slice(key)
            if start >= 0:
                if stop is None:
                    stmt = self._select_asc.offset(start)
                elif stop >= 0:
                    stmt = self._select_asc.slice(start, stop)
                else:
                    stmt = self._select_asc.slice(start, len(self) + stop)
                rows = [self._process_row(row) for row in self._session.execute(stmt)]
                if reverse:
                    rows.reverse()
            else:
                if stop is None:
                    stmt = self._select_desc.limit(-start)
                elif stop < 0:
                    stmt = self._select_desc.slice(-stop, -start)
                else:
                    stmt = self._select_desc.slice(len(self) - stop, -start)
                rows = [self._process_row(row) for row in self._session.execute(stmt)]
                if not reverse:
                    rows.reverse()
            return rows
        raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")
2095def _coerce_index(value: Any) -> int | None:
2096 """Check slice attribute's type and convert it to int.
2098 See CPython documentation on this:
2099 https://docs.python.org/3/reference/datamodel.html#object.__index__
2100 """
2101 if value is None or isinstance(value, int):
2102 return value
2103 if (index := getattr(value, "__index__", None)) is not None:
2104 return index()
2105 raise TypeError("slice indices must be integers or None or have an __index__ method")
def _coerce_slice(key: slice) -> tuple[int, int | None, bool]:
    """Split a slice into ``(start, stop, reverse)`` suitable for SQL paging.

    Only steps of ``None``/``1`` (forward) and ``-1`` (reverse) are accepted;
    any other step raises ``ValueError``. A missing start becomes ``0``, and
    stop stays ``None`` when omitted. See:
    https://docs.python.org/3/reference/datamodel.html#slice-objects

    NOTE(review): for step -1 the bounds are passed through unchanged, so a
    caller that treats ``reverse`` as "reverse the forward slice [start:stop]"
    matches Python's ``seq[::-1]`` but not ``seq[a:b:-1]`` in general; confirm
    this is the intended contract.
    """
    step = key.step
    forward = step is None or step == 1
    if not forward and step != -1:
        raise ValueError("non-trivial slice step not supported")
    start = _coerce_index(key.start) or 0
    stop = _coerce_index(key.stop)
    return start, stop, not forward