Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/pool.py: 49%

115 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

18from __future__ import annotations 

19 

20from typing import Any, Iterable 

21 

22from sqlalchemy import Column, Integer, String, Text, func 

23from sqlalchemy.orm.session import Session 

24 

25from airflow.exceptions import AirflowException, PoolNotFound 

26from airflow.models.base import Base 

27from airflow.ti_deps.dependencies_states import EXECUTION_STATES 

28from airflow.typing_compat import TypedDict 

29from airflow.utils.session import NEW_SESSION, provide_session 

30from airflow.utils.sqlalchemy import nowait, with_row_locks 

31from airflow.utils.state import State 

32 

33 

class PoolStats(TypedDict):
    """
    Slot usage summary for a single pool, as assembled by ``Pool.slots_stats``.

    For a pool configured with ``slots == -1`` (unlimited), ``total`` and therefore ``open``
    hold ``float("inf")`` even though they are annotated as ``int``.
    """

    total: int  # configured slot capacity of the pool
    running: int  # summed pool_slots of task instances in the running state
    queued: int  # summed pool_slots of task instances in the queued state
    open: int  # total - running - queued

41 

42 

class Pool(Base):
    """
    ORM model for a row of the ``slot_pool`` table.

    Each row names a pool and its slot count (``-1`` meaning unlimited). The helper methods
    report how many of those slots are taken by task instances in particular states.
    """

    __tablename__ = "slot_pool"

    id = Column(Integer, primary_key=True)
    pool = Column(String(256), unique=True)
    # -1 for infinite
    slots = Column(Integer, default=0)
    description = Column(Text)

    DEFAULT_POOL_NAME = "default_pool"

    def __repr__(self):
        return str(self.pool)

    @staticmethod
    @provide_session
    def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
        """
        Get all pools.

        :param session: SQLAlchemy ORM Session
        :return: every row of the ``slot_pool`` table
        """
        return session.query(Pool).all()

    @staticmethod
    @provide_session
    def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool with specific pool name from the Pools.

        :param pool_name: The pool name of the Pool to get.
        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None if no pool has that name
        """
        return session.query(Pool).filter(Pool.pool == pool_name).first()

    @staticmethod
    @provide_session
    def get_default_pool(session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the pool named ``DEFAULT_POOL_NAME`` (``default_pool``).

        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None if no such row exists
        """
        return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)

    @staticmethod
    @provide_session
    def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
        """
        Check whether the pool with the given id is the default pool.

        :param id: pool id (the name shadows the builtin but is kept for keyword callers)
        :param session: SQLAlchemy ORM Session
        :return: True if id is default_pool, otherwise False
        """
        return (
            session.query(func.count(Pool.id))
            .filter(Pool.id == id, Pool.pool == Pool.DEFAULT_POOL_NAME)
            .scalar()
            > 0
        )

    @staticmethod
    @provide_session
    def create_or_update_pool(
        name: str,
        slots: int,
        description: str,
        session: Session = NEW_SESSION,
    ) -> Pool:
        """
        Create a pool with given parameters or update it if it already exists.

        The change is committed on ``session`` before returning.

        :param name: pool name; must be non-empty
        :param slots: number of slots (``-1`` for unlimited)
        :param description: free-text description stored with the pool
        :param session: SQLAlchemy ORM Session
        :return: the created or updated pool object
        :raises ValueError: if ``name`` is empty
        """
        if not name:
            raise ValueError("Pool name must not be empty")

        pool = session.query(Pool).filter_by(pool=name).one_or_none()
        if pool is None:
            pool = Pool(pool=name, slots=slots, description=description)
            session.add(pool)
        else:
            pool.slots = slots
            pool.description = description

        session.commit()
        return pool

    @staticmethod
    @provide_session
    def delete_pool(name: str, session: Session = NEW_SESSION) -> Pool:
        """
        Delete pool by a given name and commit the deletion.

        :param name: name of the pool to delete
        :param session: SQLAlchemy ORM Session
        :return: the deleted pool object
        :raises AirflowException: if ``name`` is the default pool, which cannot be deleted
        :raises PoolNotFound: if no pool named ``name`` exists
        """
        if name == Pool.DEFAULT_POOL_NAME:
            raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")

        pool = session.query(Pool).filter_by(pool=name).first()
        if pool is None:
            raise PoolNotFound(f"Pool '{name}' doesn't exist")

        session.delete(pool)
        session.commit()

        return pool

    @staticmethod
    @provide_session
    def slots_stats(
        *,
        lock_rows: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, PoolStats]:
        """
        Get Pool stats (Number of Running, Queued, Open & Total tasks).

        If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
        non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an
        OperationalError.

        :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returned
        :param session: SQLAlchemy ORM Session
        :return: mapping of pool name to its stats; ``total`` and ``open`` are ``float("inf")``
            for pools whose ``slots`` is ``-1``
        :raises AirflowException: if the grouped task-instance query yields a state other than
            running or queued
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        pools: dict[str, PoolStats] = {}

        query = session.query(Pool.pool, Pool.slots)
        if lock_rows:
            query = with_row_locks(query, session=session, **nowait(session))

        pool_rows: Iterable[tuple[str, int]] = query.all()
        for pool_name, total_slots in pool_rows:
            if total_slots == -1:
                total_slots = float("inf")  # type: ignore
            pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0)

        state_count_by_pool = (
            session.query(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
            .group_by(TaskInstance.pool, TaskInstance.state)
        ).all()

        # calculate queued and running metrics
        for pool_name, state, count in state_count_by_pool:
            # Some databases return decimal.Decimal here.
            count = int(count)

            stats_dict = pools.get(pool_name)
            # ``pools`` only holds pools present in slot_pool; ignore task instances of any other pool.
            if stats_dict is None:
                continue
            # TypedDict key must be a string literal, so we use if-statements to set value
            if state == State.RUNNING:
                stats_dict["running"] = count
            elif state == State.QUEUED:
                stats_dict["queued"] = count
            else:
                raise AirflowException(f"Unexpected state. Expected values: {EXECUTION_STATES}.")

        # calculate open metric
        for stats_dict in pools.values():
            stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"]

        return pools

    def to_json(self) -> dict[str, Any]:
        """
        Get the Pool in a json structure.

        :return: the pool object in json format
        """
        return {
            "id": self.id,
            "pool": self.pool,
            "slots": self.slots,
            "description": self.description,
        }

    def _sum_pool_slots(self, states: Iterable[str], session: Session) -> int:
        """
        Sum ``pool_slots`` over this pool's task instances whose state is in ``states``.

        :param states: task-instance states to include
        :param session: SQLAlchemy ORM Session
        :return: the summed slots, or 0 when no task instance matches
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        used = (
            session.query(func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state.in_(list(states)))
            .scalar()
        )
        # SUM over zero matching rows yields NULL (None), hence the fallback to 0.
        return int(used or 0)

    @provide_session
    def occupied_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running/queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_pool_slots(EXECUTION_STATES, session)

    @provide_session
    def running_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_pool_slots([State.RUNNING], session)

    @provide_session
    def queued_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_pool_slots([State.QUEUED], session)

    @provide_session
    def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots scheduled at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of scheduled slots
        """
        return self._sum_pool_slots([State.SCHEDULED], session)

    @provide_session
    def open_slots(self, session: Session = NEW_SESSION) -> float:
        """
        Get the number of slots open at the moment.

        :param session: SQLAlchemy ORM Session
        :return: ``float("inf")`` for an unlimited pool (``slots == -1``), otherwise
            ``slots`` minus the slots occupied by running/queued tasks
        """
        if self.slots == -1:
            return float("inf")
        return self.slots - self.occupied_slots(session=session)