Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/dagcode.py: 47%
96 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import logging
20import os
21import struct
22from datetime import datetime
23from typing import Iterable
25from sqlalchemy import BigInteger, Column, String, Text
26from sqlalchemy.dialects.mysql import MEDIUMTEXT
27from sqlalchemy.sql.expression import literal
29from airflow.exceptions import AirflowException, DagCodeNotFound
30from airflow.models.base import Base
31from airflow.utils import timezone
32from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
33from airflow.utils.session import provide_session
34from airflow.utils.sqlalchemy import UtcDateTime
36log = logging.getLogger(__name__)
class DagCode(Base):
    """A table for DAGs code.

    dag_code table contains code of DAG files synchronized by scheduler.

    For details on dag serialization see SerializedDagModel
    """

    __tablename__ = "dag_code"

    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the limit of indexing.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)

    def __init__(self, full_filepath: str, source_code: str | None = None):
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        # If no source was supplied, fall back to reading it from the DB.
        self.source_code = source_code or DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session=None):
        """Writes code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None):
        """Writes code in bulk into database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        :raises AirflowException: if two distinct file paths hash to the same
            ``fileloc_hash`` (a hash collision in the primary key)
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
        # Lock the matching rows so concurrent schedulers don't race on updates.
        existing_orm_dag_codes = (
            session.query(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
            .all()
        )

        if existing_orm_dag_codes:
            existing_orm_dag_codes_map = {
                orm_dag_code.fileloc: orm_dag_code for orm_dag_code in existing_orm_dag_codes
            }
        else:
            existing_orm_dag_codes_map = {}

        existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
        existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
        if not existing_orm_filelocs.issubset(filelocs):
            # A row with the same hash exists but stores a *different* path:
            # two paths collided on dag_fileloc_hash. Refuse to overwrite.
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
            message = ""
            for fileloc in conflicting_filelocs:
                filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
                # BUGFIX: interpolate the colliding filename instead of the
                # hard-coded literal "(unknown)" (the variable was computed
                # above but never used).
                message += (
                    f"Filename '{filename}' causes a hash collision in the "
                    f"database with '{fileloc}'. Please rename the file."
                )
            raise AirflowException(message)

        existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
        missing_filelocs = filelocs.difference(existing_filelocs)

        # Insert rows for files not yet in the table.
        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)

        # Refresh rows whose file on disk is newer than the stored copy.
        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )

            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(cls, alive_dag_filelocs: list[str], session=None):
        """Deletes code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table ", cls.__tablename__)

        session.query(cls).filter(
            cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs)
        ).delete(synchronize_session="fetch")

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session=None) -> bool:
        """Checks a file exist in dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return session.query(literal(True)).filter(cls.fileloc_hash == fileloc_hash).one_or_none() is not None

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """Returns source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    def code(cls, fileloc) -> str:
        """Returns source code for this DagCode object.

        :return: source code as string
        """
        return cls._get_code_from_db(fileloc)

    @staticmethod
    def _get_code_from_file(fileloc):
        # Transparently handles DAG files packaged inside zip archives.
        with open_maybe_zipped(fileloc, "r") as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session=None):
        # Look up by the indexed hash rather than the (too long to index) path.
        dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first()
        if not dag_code:
            raise DagCodeNotFound()
        else:
            code = dag_code.source_code
        return code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """Hashing file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib

        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
        return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8