Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/dagcode.py: 47%
97 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17from __future__ import annotations
19import logging
20import os
21import struct
22from datetime import datetime
23from typing import Iterable
25from sqlalchemy import BigInteger, Column, String, Text, delete
26from sqlalchemy.dialects.mysql import MEDIUMTEXT
27from sqlalchemy.orm import Session
28from sqlalchemy.sql.expression import literal
30from airflow.exceptions import AirflowException, DagCodeNotFound
31from airflow.models.base import Base
32from airflow.utils import timezone
33from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
34from airflow.utils.session import NEW_SESSION, provide_session
35from airflow.utils.sqlalchemy import UtcDateTime
37log = logging.getLogger(__name__)
class DagCode(Base):
    """A table for DAGs code.

    dag_code table contains code of DAG files synchronized by scheduler.

    For details on dag serialization see SerializedDagModel
    """

    __tablename__ = "dag_code"

    # fileloc is hashed down to a BigInteger primary key because the raw path
    # (up to 2000 chars) exceeds database index-length limits.
    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the limit of indexing.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)

    def __init__(self, full_filepath: str, source_code: str | None = None):
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        # Fall back to reading the code from the DB when not supplied.
        self.source_code = source_code or DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session: Session = NEW_SESSION) -> None:
        """Writes code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION) -> None:
        """Writes code in bulk into database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        :raises AirflowException: if two distinct file paths hash to the same
            ``fileloc_hash`` (a hash collision in the primary key)
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
        # Lock the matching rows so concurrent schedulers don't race on updates.
        existing_orm_dag_codes = (
            session.query(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
            .all()
        )

        if existing_orm_dag_codes:
            existing_orm_dag_codes_map = {
                orm_dag_code.fileloc: orm_dag_code for orm_dag_code in existing_orm_dag_codes
            }
        else:
            existing_orm_dag_codes_map = {}

        existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
        existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
        # Rows whose hash matched but whose path differs indicate a hash collision.
        if not existing_orm_filelocs.issubset(filelocs):
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
            message = ""
            for fileloc in conflicting_filelocs:
                filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
                # BUG FIX: interpolate the colliding filename; previously the
                # message contained the literal text '(unknown)' and the
                # computed `filename` was never used.
                message += (
                    f"Filename '{filename}' causes a hash collision in the "
                    f"database with '{fileloc}'. Please rename the file."
                )
            raise AirflowException(message)

        existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
        missing_filelocs = filelocs.difference(existing_filelocs)

        # Insert rows for files not yet in the table.
        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)

        # Refresh rows whose on-disk file is newer than the stored copy.
        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )

            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(cls, alive_dag_filelocs: list[str], session: Session = NEW_SESSION) -> None:
        """Deletes code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table ", cls.__tablename__)

        session.execute(
            delete(cls)
            .where(cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs))
            .execution_options(synchronize_session="fetch")
        )

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool:
        """Checks a file exist in dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return session.query(literal(True)).filter(cls.fileloc_hash == fileloc_hash).one_or_none() is not None

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """Returns source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    def code(cls, fileloc: str) -> str:
        """Returns source code for this DagCode object.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls._get_code_from_db(fileloc)

    @staticmethod
    def _get_code_from_file(fileloc):
        # Reads DAG source straight from disk (handles files inside zips).
        with open_maybe_zipped(fileloc, "r") as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
        """Fetch the stored source for *fileloc* from the dag_code table.

        :raises DagCodeNotFound: if no row matches the fileloc hash
        """
        dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first()
        if not dag_code:
            raise DagCodeNotFound()
        else:
            code = dag_code.source_code
        return code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """Hashing file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib

        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
        return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8