Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/dagcode.py: 47%

96 statements, coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import os
import struct
from datetime import datetime
from typing import Iterable

from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.sql.expression import literal

from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models.base import Base
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime

log = logging.getLogger(__name__)


class DagCode(Base):
    """A table for DAG code.

    The dag_code table contains the code of DAG files synchronized by the scheduler.

    For details on DAG serialization see SerializedDagModel.
    """

    __tablename__ = "dag_code"

    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc (2000) exceeds the index key limit, hence the hashed primary key above.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)

    def __init__(self, full_filepath: str, source_code: str | None = None):
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        self.source_code = source_code or DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session=None):
        """Write code into the database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)
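
    # A hedged usage sketch (the path is hypothetical; assumes an initialized
    # Airflow metadata database so @provide_session can supply the session):
    #
    #     DagCode("/opt/airflow/dags/example_dag.py").sync_to_db()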

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None):
        """Write code in bulk into the database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
        # Lock any rows we may update so concurrent sync runs do not clobber each other.
        existing_orm_dag_codes = (
            session.query(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
            .all()
        )

        if existing_orm_dag_codes:
            existing_orm_dag_codes_map = {
                orm_dag_code.fileloc: orm_dag_code for orm_dag_code in existing_orm_dag_codes
            }
        else:
            existing_orm_dag_codes_map = {}

        existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}
        existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
        # Any stored fileloc outside the requested set was matched purely by hash: a collision.
        if not existing_orm_filelocs.issubset(filelocs):
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {DagCode.dag_fileloc_hash(fileloc): fileloc for fileloc in filelocs}
            message = ""
            for fileloc in conflicting_filelocs:
                filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
                message += (
                    f"Filename '{filename}' causes a hash collision in the "
                    f"database with '{fileloc}'. Please rename the file."
                )
            raise AirflowException(message)

        existing_filelocs = {dag_code.fileloc for dag_code in existing_orm_dag_codes}
        missing_filelocs = filelocs.difference(existing_filelocs)

        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)

        # Rewrite rows whose file on disk changed since they were last stored.
        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )

            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)
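
    # Bulk variant, same assumptions as the sketch above; existing rows are
    # rewritten only when the file's mtime is newer than the stored
    # last_updated timestamp:
    #
    #     DagCode.bulk_sync_to_db(
    #         ["/opt/airflow/dags/a.py", "/opt/airflow/dags/b.py"]
    #     )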

    @classmethod
    @provide_session
    def remove_deleted_code(cls, alive_dag_filelocs: list[str], session=None):
        """Delete code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table", cls.__tablename__)

        session.query(cls).filter(
            cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs)
        ).delete(synchronize_session="fetch")
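
    # Illustrative call (hypothetical path): a row is deleted only when both
    # its stored hash and its stored path are absent from the alive list, a
    # belt-and-braces guard around hash collisions.
    #
    #     DagCode.remove_deleted_code(["/opt/airflow/dags/a.py"])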

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session=None) -> bool:
        """Check whether a file exists in the dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return session.query(literal(True)).filter(cls.fileloc_hash == fileloc_hash).one_or_none() is not None
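
    # Illustrative call: the literal(True) query is an existence probe,
    # rendered roughly as "SELECT true FROM dag_code WHERE fileloc_hash = :h",
    # so the row's (potentially large) source_code is never fetched.
    #
    #     DagCode.has_dag("/opt/airflow/dags/example_dag.py")  # -> bool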

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """Return source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    def code(cls, fileloc: str) -> str:
        """Return the source code stored in the database for the given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls._get_code_from_db(fileloc)
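
    # Hedged usage sketch: DagCodeNotFound is raised when no row matches the
    # hash of the given (hypothetical) path.
    #
    #     try:
    #         source = DagCode.code("/opt/airflow/dags/example_dag.py")
    #     except DagCodeNotFound:
    #         source = None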

    @staticmethod
    def _get_code_from_file(fileloc):
        with open_maybe_zipped(fileloc, "r") as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session=None):
        dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first()
        if not dag_code:
            raise DagCodeNotFound()
        else:
            code = dag_code.source_code
        return code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """Hash the file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because fileloc may be up to 2000 characters by
        # Airflow convention, which is over the index key limit.
        import hashlib

        # Keep only 7 bytes because a MySQL BigInteger holds 8 bytes, signed.
        return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8
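

# A minimal, illustrative check of the hashing scheme above (a sketch, not
# part of Airflow's API): unpacking the last 8 digest bytes as an unsigned
# big-endian integer and shifting right by 8 keeps the top 7 bytes, so every
# hash lands in [0, 2**56) and always fits in a signed 64-bit BigInteger.
if __name__ == "__main__":
    h = DagCode.dag_fileloc_hash("/opt/airflow/dags/example_dag.py")  # hypothetical path
    assert 0 <= h < 2**56
    print(f"fileloc hash: {h}")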