Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/pool.py: 45%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

139 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20from typing import TYPE_CHECKING, Any 

21 

22from sqlalchemy import Boolean, Column, Integer, String, Text, func, select 

23 

24from airflow.exceptions import AirflowException, PoolNotFound 

25from airflow.models.base import Base 

26from airflow.ti_deps.dependencies_states import EXECUTION_STATES 

27from airflow.typing_compat import TypedDict 

28from airflow.utils.db import exists_query 

29from airflow.utils.session import NEW_SESSION, provide_session 

30from airflow.utils.sqlalchemy import with_row_locks 

31from airflow.utils.state import TaskInstanceState 

32 

33if TYPE_CHECKING: 

34 from sqlalchemy.orm.session import Session 

35 

36 

class PoolStats(TypedDict):
    """
    Dictionary containing Pool Stats.

    ``Pool.slots_stats`` returns one of these per pool, keyed by pool name. Every count is a
    sum of task-instance ``pool_slots``, not a number of task instances.
    """

    # Configured capacity of the pool; set to float("inf") when the pool's ``slots`` is -1.
    total: int
    # Slots held by task instances in the RUNNING state.
    running: int
    # Slots held by task instances in the DEFERRED state.
    deferred: int
    # Slots held by task instances in the QUEUED state.
    queued: int
    # total - running - queued, additionally minus deferred when the pool has include_deferred set
    # (infinite when total is infinite).
    open: int
    # Slots held by task instances in the SCHEDULED state.
    scheduled: int

46 

47 

class Pool(Base):
    """
    ORM model for the ``slot_pool`` table.

    Each row is a named pool with a slot capacity. Task instances refer to a pool by name
    (``TaskInstance.pool``) and are counted against it by their ``pool_slots``. Static helpers
    look up, create, update and delete pools; instance helpers report the pool's current usage.
    """

    __tablename__ = "slot_pool"

    id = Column(Integer, primary_key=True)
    pool = Column(String(256), unique=True)
    # -1 for infinite
    slots = Column(Integer, default=0)
    description = Column(Text)
    # When True, DEFERRED task instances also count against this pool's slots.
    include_deferred = Column(Boolean, nullable=False)

    DEFAULT_POOL_NAME = "default_pool"

    def __repr__(self):
        return str(self.pool)

    @staticmethod
    @provide_session
    def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
        """
        Get all pools.

        :param session: SQLAlchemy ORM Session
        :return: every Pool row
        """
        return session.scalars(select(Pool)).all()

    @staticmethod
    @provide_session
    def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool with specific pool name from the Pools.

        :param pool_name: The pool name of the Pool to get.
        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None if no pool has that name
        """
        return session.scalar(select(Pool).where(Pool.pool == pool_name))

    @staticmethod
    @provide_session
    def get_default_pool(session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool named ``DEFAULT_POOL_NAME``.

        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None if the default pool does not exist
        """
        return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)

    @staticmethod
    @provide_session
    def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
        """
        Check whether the pool with the given id is the default pool.

        :param id: pool id
        :param session: SQLAlchemy ORM Session
        :return: True if id is default_pool, otherwise False
        """
        return exists_query(
            Pool.id == id,
            Pool.pool == Pool.DEFAULT_POOL_NAME,
            session=session,
        )

    @staticmethod
    @provide_session
    def create_or_update_pool(
        name: str,
        slots: int,
        description: str,
        include_deferred: bool,
        session: Session = NEW_SESSION,
    ) -> Pool:
        """
        Create a pool with given parameters or update it if it already exists.

        The session is committed before returning.

        :param name: pool name; must be non-empty
        :param slots: slot capacity (-1 for infinite)
        :param description: free-text description
        :param include_deferred: whether DEFERRED task instances count against the pool
        :param session: SQLAlchemy ORM Session
        :raises ValueError: if ``name`` is empty
        :return: the created or updated pool
        """
        if not name:
            raise ValueError("Pool name must not be empty")

        pool = session.scalar(select(Pool).where(Pool.pool == name))
        if pool is None:
            pool = Pool(pool=name, slots=slots, description=description, include_deferred=include_deferred)
            session.add(pool)
        else:
            pool.slots = slots
            pool.description = description
            pool.include_deferred = include_deferred

        session.commit()
        return pool

    @staticmethod
    @provide_session
    def delete_pool(name: str, session: Session = NEW_SESSION) -> Pool:
        """
        Delete pool by a given name.

        The session is committed before returning.

        :param name: name of the pool to delete
        :param session: SQLAlchemy ORM Session
        :raises AirflowException: if ``name`` is the default pool
        :raises PoolNotFound: if no pool has that name
        :return: the deleted pool
        """
        if name == Pool.DEFAULT_POOL_NAME:
            raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")

        pool = session.scalar(select(Pool).where(Pool.pool == name))
        if pool is None:
            raise PoolNotFound(f"Pool '{name}' doesn't exist")

        session.delete(pool)
        session.commit()

        return pool

    @staticmethod
    @provide_session
    def slots_stats(
        *,
        lock_rows: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, PoolStats]:
        """
        Get Pool stats (Number of Running, Queued, Open & Total tasks).

        If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
        non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an
        OperationalError.

        :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returns
        :param session: SQLAlchemy ORM Session
        :raises AirflowException: if the database returns a state outside the queried set
        :return: a PoolStats per pool, keyed by pool name
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        pools: dict[str, PoolStats] = {}
        pool_includes_deferred: dict[str, bool] = {}

        query = select(Pool.pool, Pool.slots, Pool.include_deferred)

        if lock_rows:
            query = with_row_locks(query, session=session, nowait=True)

        pool_rows = session.execute(query)
        for pool_name, total_slots, include_deferred in pool_rows:
            if total_slots == -1:
                total_slots = float("inf")  # type: ignore
            pools[pool_name] = PoolStats(
                total=total_slots, running=0, queued=0, open=0, deferred=0, scheduled=0
            )
            pool_includes_deferred[pool_name] = include_deferred

        allowed_execution_states = EXECUTION_STATES | {
            TaskInstanceState.DEFERRED,
            TaskInstanceState.SCHEDULED,
        }
        state_count_by_pool = session.execute(
            select(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.state.in_(allowed_execution_states))
            .group_by(TaskInstance.pool, TaskInstance.state)
        )

        # calculate queued and running metrics
        for pool_name, state, count in state_count_by_pool:
            # Some databases return decimal.Decimal here.
            count = int(count)

            stats_dict: PoolStats | None = pools.get(pool_name)
            # Task instances may reference a pool name that has no slot_pool row; skip them.
            if stats_dict is None:
                continue
            # TypedDict key must be a string literal, so we use if-statements to set value
            if state == TaskInstanceState.RUNNING:
                stats_dict["running"] = count
            elif state == TaskInstanceState.QUEUED:
                stats_dict["queued"] = count
            elif state == TaskInstanceState.DEFERRED:
                stats_dict["deferred"] = count
            elif state == TaskInstanceState.SCHEDULED:
                stats_dict["scheduled"] = count
            else:
                raise AirflowException(f"Unexpected state. Expected values: {allowed_execution_states}.")

        # calculate open metric
        for pool_name, stats_dict in pools.items():
            stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"]
            if pool_includes_deferred[pool_name]:
                stats_dict["open"] -= stats_dict["deferred"]

        return pools

    def to_json(self) -> dict[str, Any]:
        """
        Get the Pool in a json structure.

        :return: the pool object in json format
        """
        return {
            "id": self.id,
            "pool": self.pool,
            "slots": self.slots,
            "description": self.description,
            "include_deferred": self.include_deferred,
        }

    @provide_session
    def occupied_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots currently occupied in this pool.

        A task instance occupies slots while its state is one of :meth:`get_occupied_states`:
        ``EXECUTION_STATES``, plus ``DEFERRED`` when this pool has ``include_deferred`` set.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_slots_in_states(self.get_occupied_states(), session=session)

    def get_occupied_states(self):
        """
        Return the task-instance states that count against this pool's slots.

        :return: ``EXECUTION_STATES``, extended with ``DEFERRED`` when ``include_deferred`` is set
        """
        if self.include_deferred:
            return EXECUTION_STATES | {
                TaskInstanceState.DEFERRED,
            }
        return EXECUTION_STATES

    def _sum_slots_in_states(self, states: set[TaskInstanceState], session: Session) -> int:
        """
        Sum ``pool_slots`` of this pool's task instances whose state is in ``states``.

        :param states: task-instance states to include
        :param session: SQLAlchemy ORM Session
        :return: the summed slots, 0 when no task instance matches
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        total = session.scalar(
            select(func.sum(TaskInstance.pool_slots)).where(
                TaskInstance.pool == self.pool,
                TaskInstance.state.in_(states),
            )
        )
        # SUM over no rows yields NULL (None); int() also normalizes decimal.Decimal results.
        return int(total or 0)

    @provide_session
    def running_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_slots_in_states({TaskInstanceState.RUNNING}, session=session)

    @provide_session
    def queued_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_slots_in_states({TaskInstanceState.QUEUED}, session=session)

    @provide_session
    def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots scheduled at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of scheduled slots
        """
        return self._sum_slots_in_states({TaskInstanceState.SCHEDULED}, session=session)

    @provide_session
    def deferred_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots deferred at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of deferred slots
        """
        return self._sum_slots_in_states({TaskInstanceState.DEFERRED}, session=session)

    @provide_session
    def open_slots(self, session: Session = NEW_SESSION) -> float:
        """
        Get the number of slots open at the moment.

        :param session: SQLAlchemy ORM Session
        :return: ``slots`` minus :meth:`occupied_slots`, or float("inf") when ``slots`` is -1
        """
        if self.slots == -1:
            return float("inf")
        return self.slots - self.occupied_slots(session=session)