Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/dagcode.py: 47%

97 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import logging 

20import os 

21import struct 

22from datetime import datetime 

23from typing import Iterable 

24 

25from sqlalchemy import BigInteger, Column, String, Text, delete 

26from sqlalchemy.dialects.mysql import MEDIUMTEXT 

27from sqlalchemy.orm import Session 

28from sqlalchemy.sql.expression import literal 

29 

30from airflow.exceptions import AirflowException, DagCodeNotFound 

31from airflow.models.base import Base 

32from airflow.utils import timezone 

33from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped 

34from airflow.utils.session import NEW_SESSION, provide_session 

35from airflow.utils.sqlalchemy import UtcDateTime 

36 

# Module-level logger, named after this module per Airflow convention.
log = logging.getLogger(__name__)

38 

39 

class DagCode(Base):
    """A table for DAGs code.

    dag_code table contains code of DAG files synchronized by scheduler.

    For details on dag serialization see SerializedDagModel
    """

    __tablename__ = "dag_code"

    # The hash of ``fileloc`` is the primary key because the full path is too
    # long to index directly (see ``dag_fileloc_hash``).
    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the limit of indexing.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)

    def __init__(self, full_filepath: str, source_code: str | None = None):
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        # When no source is supplied, fall back to the copy already stored in the DB.
        self.source_code = source_code or DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session: Session = NEW_SESSION) -> None:
        """Writes code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION) -> None:
        """Writes code in bulk into database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        :raises AirflowException: if a synced file's hash collides with a
            different file path already stored in the table
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
        # Lock matching rows so concurrent schedulers don't race on updates.
        existing_orm_dag_codes = (
            session.query(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
            .all()
        )

        if existing_orm_dag_codes:
            existing_orm_dag_codes_map = {
                orm_dag_code.fileloc: orm_dag_code for orm_dag_code in existing_orm_dag_codes
            }
        else:
            existing_orm_dag_codes_map = {}

        existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
        existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
        if not existing_orm_filelocs.issubset(filelocs):
            # Rows matched on hash but carry a different path: hash collision.
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
            message = ""
            for fileloc in conflicting_filelocs:
                filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
                # BUG FIX: the message previously hard-coded '(unknown)' and
                # never used ``filename``; interpolate the colliding filename.
                message += (
                    f"Filename '{filename}' causes a hash collision in the "
                    f"database with '{fileloc}'. Please rename the file."
                )
            raise AirflowException(message)

        existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
        missing_filelocs = filelocs.difference(existing_filelocs)

        # Insert rows for files not yet present in the table.
        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)

        # Refresh stored code for files modified after the stored timestamp.
        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )

            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(cls, alive_dag_filelocs: list[str], session: Session = NEW_SESSION) -> None:
        """Deletes code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table ", cls.__tablename__)

        # Both conditions must hold: a row survives if its hash OR its path is
        # still alive, protecting hash-colliding paths from accidental deletion.
        session.execute(
            delete(cls)
            .where(cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs))
            .execution_options(synchronize_session="fetch")
        )

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool:
        """Checks a file exist in dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return session.query(literal(True)).filter(cls.fileloc_hash == fileloc_hash).one_or_none() is not None

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """Returns source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    def code(cls, fileloc) -> str:
        """Returns source code for this DagCode object.

        :return: source code as string
        """
        return cls._get_code_from_db(fileloc)

    @staticmethod
    def _get_code_from_file(fileloc):
        # Reads source from plain files or DAG files packaged inside zips.
        with open_maybe_zipped(fileloc, "r") as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
        """Returns source code stored in the database for ``fileloc``.

        :raises DagCodeNotFound: if no row matches the fileloc hash
        """
        dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first()
        if not dag_code:
            raise DagCodeNotFound()
        else:
            code = dag_code.source_code
        return code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """Hashing file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib

        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
        return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8