Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/dagcode.py: 47%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17from __future__ import annotations
19import logging
20import os
21import struct
22from datetime import datetime
23from typing import TYPE_CHECKING, Collection, Iterable
25from sqlalchemy import BigInteger, Column, String, Text, delete, select
26from sqlalchemy.dialects.mysql import MEDIUMTEXT
27from sqlalchemy.sql.expression import literal
29from airflow.exceptions import AirflowException, DagCodeNotFound
30from airflow.models.base import Base
31from airflow.utils import timezone
32from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
33from airflow.utils.session import NEW_SESSION, provide_session
34from airflow.utils.sqlalchemy import UtcDateTime
if TYPE_CHECKING:
    # Imported only for type annotations to avoid a runtime dependency cycle.
    from sqlalchemy.orm import Session

# Module-level logger, named after this module per Airflow convention.
log = logging.getLogger(__name__)
class DagCode(Base):
    """A table for DAGs code.

    dag_code table contains code of DAG files synchronized by scheduler.

    For details on dag serialization see SerializedDagModel
    """

    __tablename__ = "dag_code"

    # Primary key is a hash of the file location rather than the path itself,
    # because fileloc (length 2000) exceeds index-length limits on some backends.
    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the limit of indexing.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)

    def __init__(self, full_filepath: str, source_code: str | None = None):
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        # If no source is supplied, fall back to reading it from the DB.
        self.source_code = source_code or DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session: Session = NEW_SESSION) -> None:
        """Write code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION) -> None:
        """Write code in bulk into database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        :raises AirflowException: if two distinct file paths hash to the same
            fileloc_hash (a hash collision in the primary key)
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
        # Lock the matching rows so a concurrent scheduler cannot update them
        # while we compare timestamps below.
        existing_orm_dag_codes = session.scalars(
            select(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
        ).all()

        if existing_orm_dag_codes:
            existing_orm_dag_codes_map = {
                orm_dag_code.fileloc: orm_dag_code for orm_dag_code in existing_orm_dag_codes
            }
        else:
            existing_orm_dag_codes_map = {}

        existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
        existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
        # Rows whose hash matched but whose fileloc differs mean two distinct
        # paths collided on fileloc_hash — refuse to overwrite and report all
        # colliding pairs.
        if not existing_orm_filelocs.issubset(filelocs):
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
            message = ""
            for fileloc in conflicting_filelocs:
                filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
                # BUG FIX: the message previously interpolated the literal
                # '(unknown)' instead of the colliding filename.
                message += (
                    f"Filename '{filename}' causes a hash collision in the "
                    f"database with '{fileloc}'. Please rename the file."
                )
            raise AirflowException(message)

        existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
        missing_filelocs = filelocs.difference(existing_filelocs)

        # Insert rows for files not yet in the table.
        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)

        # Refresh rows whose file on disk is newer than the stored copy.
        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )

            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(
        cls,
        alive_dag_filelocs: Collection[str],
        processor_subdir: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """Delete code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param processor_subdir: dag processor subdir
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table ", cls.__tablename__)

        # Only rows under this processor's subdir are eligible for deletion,
        # so concurrent processors don't delete each other's DAGs.
        session.execute(
            delete(cls)
            .where(
                cls.fileloc_hash.notin_(alive_fileloc_hashes),
                cls.fileloc.notin_(alive_dag_filelocs),
                cls.fileloc.contains(processor_subdir),
            )
            .execution_options(synchronize_session="fetch")
        )

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool:
        """Check a file exist in dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        # SELECT a literal instead of the row: we only need existence.
        return (
            session.scalars(select(literal(True)).where(cls.fileloc_hash == fileloc_hash)).one_or_none()
            is not None
        )

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """Return source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    @provide_session
    def code(cls, fileloc, session: Session = NEW_SESSION) -> str:
        """Return source code for this DagCode object.

        :return: source code as string
        """
        return cls._get_code_from_db(fileloc, session)

    @staticmethod
    def _get_code_from_file(fileloc):
        # Read the DAG source from disk; supports files embedded in zip archives.
        with open_maybe_zipped(fileloc, "r") as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
        # Look the row up by hash (the primary key), not by the full path.
        dag_code = session.scalar(select(cls).where(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)))
        if not dag_code:
            raise DagCodeNotFound()
        else:
            code = dag_code.source_code
        return code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """Hashing file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib

        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
        return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8