Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/dagcode.py: 47%


99 statements  

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import os
import struct
from datetime import datetime
from typing import TYPE_CHECKING, Collection, Iterable

from sqlalchemy import BigInteger, Column, String, Text, delete, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.sql.expression import literal

from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models.base import Base
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

log = logging.getLogger(__name__)

class DagCode(Base):
    """A table for DAGs code.

    The dag_code table contains the code of DAG files synchronized by the scheduler.

    For details on DAG serialization see SerializedDagModel.
    """

    __tablename__ = "dag_code"

    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc (2000) exceeds the limit for an indexed column,
    # so the hash above serves as the primary key instead.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)
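    # Text has no practical size cap on Postgres or SQLite, but MySQL's TEXT
    # holds at most 64 KB; the MEDIUMTEXT variant above (16 MB) leaves room
    # for large DAG files.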

    def __init__(self, full_filepath: str, source_code: str | None = None):
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        self.source_code = source_code or DagCode.code(self.fileloc)
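        # (If no source was passed, DagCode.code falls back to whatever is
        # already stored in the database for this path.)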

    @provide_session
    def sync_to_db(self, session: Session = NEW_SESSION) -> None:
        """Write code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION) -> None:
        """Write code in bulk into database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}

        existing_orm_dag_codes = session.scalars(
            select(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
        ).all()
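        # with_for_update() emits SELECT ... FOR UPDATE, row-locking the matched
        # rows so concurrent sync runs serialize instead of clobbering each other.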

        if existing_orm_dag_codes:
            existing_orm_dag_codes_map = {
                orm_dag_code.fileloc: orm_dag_code for orm_dag_code in existing_orm_dag_codes
            }
        else:
            existing_orm_dag_codes_map = {}

        existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
        existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
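        # fileloc_hash is the primary key, so two distinct paths hashing to the
        # same value would collide; the check below raises instead of silently
        # overwriting one file's code with the other's.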

        if not existing_orm_filelocs.issubset(filelocs):
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
            message = ""
            for fileloc in conflicting_filelocs:
                filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
                message += (
                    f"Filename '{filename}' causes a hash collision in the "
                    f"database with '{fileloc}'. Please rename the file."
                )
            raise AirflowException(message)

        existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
        missing_filelocs = filelocs.difference(existing_filelocs)

        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)
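        # The loop below re-reads files that are already in the table, refreshing
        # the stored source only when the file's mtime on disk is newer than
        # the row's last_updated timestamp.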

        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )

            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(
        cls,
        alive_dag_filelocs: Collection[str],
        processor_subdir: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """Delete code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param processor_subdir: dag processor subdir
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table ", cls.__tablename__)

        session.execute(
            delete(cls)
            .where(
                cls.fileloc_hash.notin_(alive_fileloc_hashes),
                cls.fileloc.notin_(alive_dag_filelocs),
                cls.fileloc.contains(processor_subdir),
            )
            .execution_options(synchronize_session="fetch")
        )
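        # The where() conditions are ANDed: a row is removed only when neither
        # its hash nor its path is in the alive set and its path lies under
        # this processor's subdir, so other DAG processors' rows survive.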

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool:
        """Check whether a file exists in the dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
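        # The query below selects the constant literal(True) instead of the row,
        # so existence can be checked without loading source_code.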

        return (
            session.scalars(select(literal(True)).where(cls.fileloc_hash == fileloc_hash)).one_or_none()
            is not None
        )

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """Return source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    @provide_session
    def code(cls, fileloc, session: Session = NEW_SESSION) -> str:
        """Return the source code stored in the database for the given fileloc.

        :return: source code as string
        """
        return cls._get_code_from_db(fileloc, session)

    @staticmethod
    def _get_code_from_file(fileloc):
        with open_maybe_zipped(fileloc, "r") as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
        dag_code = session.scalar(select(cls).where(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)))
        if not dag_code:
            raise DagCodeNotFound()
        return dag_code.source_code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """Hash the file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because fileloc may be up to 2000 characters long
        # (an Airflow convention), which is more than most databases can index.
        import hashlib

        # Keep only 7 bytes because a MySQL BigInteger holds 8 bytes and is signed.
        return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8
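
Below is a minimal, standalone sketch of the hashing scheme implemented by dag_fileloc_hash, useful for verifying the arithmetic by hand; the fileloc_hash helper and the DAG path are hypothetical stand-ins, not part of this module:

import hashlib
import struct

def fileloc_hash(full_filepath: str) -> int:
    # Same arithmetic as DagCode.dag_fileloc_hash: read the last 8 bytes of the
    # SHA-1 digest as a big-endian unsigned integer, then shift right one byte
    # so the 56-bit result always fits in a signed 64-bit BigInteger column.
    return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8

h = fileloc_hash("/opt/airflow/dags/example_dag.py")  # hypothetical path
assert 0 <= h < 2**56  # never negative, so it cannot overflow the signed column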