1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20from typing import TYPE_CHECKING, Any
21
22from sqlalchemy import Boolean, Column, Integer, String, Text, func, select
23
24from airflow.exceptions import AirflowException, PoolNotFound
25from airflow.models.base import Base
26from airflow.ti_deps.dependencies_states import EXECUTION_STATES
27from airflow.typing_compat import TypedDict
28from airflow.utils.db import exists_query
29from airflow.utils.session import NEW_SESSION, provide_session
30from airflow.utils.sqlalchemy import with_row_locks
31from airflow.utils.state import TaskInstanceState
32
33if TYPE_CHECKING:
34 from sqlalchemy.orm.session import Session
35
36
class PoolStats(TypedDict):
    """
    Dictionary containing Pool Stats.

    One instance is built per pool by ``Pool.slots_stats``; every counter is a sum of
    ``TaskInstance.pool_slots`` for the task instances of that pool in the given state.
    """

    # Slot count configured on the pool. ``Pool.slots_stats`` stores ``float("inf")`` here
    # (despite the ``int`` annotation) when the pool's ``slots`` column is -1.
    total: int
    # Slots taken by task instances in the RUNNING state.
    running: int
    # Slots taken by task instances in the DEFERRED state.
    deferred: int
    # Slots taken by task instances in the QUEUED state.
    queued: int
    # ``total - running - queued``, additionally minus ``deferred`` when the pool has
    # ``include_deferred`` set. Infinite when ``total`` is infinite.
    open: int
    # Slots requested by task instances in the SCHEDULED state.
    scheduled: int
46
47
class Pool(Base):
    """
    ORM model for a row of the ``slot_pool`` table.

    Besides the column mapping, the class offers static helpers to look up, create, update
    and delete pools, and instance/static helpers that report how many slots are used by
    task instances in the various states.
    """

    __tablename__ = "slot_pool"

    id = Column(Integer, primary_key=True)
    pool = Column(String(256), unique=True)
    # -1 for infinite
    slots = Column(Integer, default=0)
    description = Column(Text)
    # When true, deferred task instances are counted as occupying slots of this pool.
    include_deferred = Column(Boolean, nullable=False)

    DEFAULT_POOL_NAME = "default_pool"

    def __repr__(self):
        return str(self.pool)

    @staticmethod
    @provide_session
    def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
        """
        Get all pools.

        :param session: SQLAlchemy ORM Session
        :return: every row of the pool table
        """
        return session.scalars(select(Pool)).all()

    @staticmethod
    @provide_session
    def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool with specific pool name from the Pools.

        :param pool_name: The pool name of the Pool to get.
        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None when no pool has that name
        """
        return session.scalar(select(Pool).where(Pool.pool == pool_name))

    @staticmethod
    @provide_session
    def get_default_pool(session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool of the default_pool from the Pools.

        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None when the default pool row is missing
        """
        return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)

    @staticmethod
    @provide_session
    def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
        """
        Check id if is the default_pool.

        :param id: pool id
        :param session: SQLAlchemy ORM Session
        :return: True if id is default_pool, otherwise False
        """
        # NOTE: the parameter name shadows the ``id`` builtin; it is kept because callers
        # may pass it by keyword.
        return exists_query(
            Pool.id == id,
            Pool.pool == Pool.DEFAULT_POOL_NAME,
            session=session,
        )

    @staticmethod
    @provide_session
    def create_or_update_pool(
        name: str,
        slots: int,
        description: str,
        include_deferred: bool,
        session: Session = NEW_SESSION,
    ) -> Pool:
        """
        Create a pool with given parameters or update it if it already exists.

        The change is committed on ``session`` before returning.

        :param name: name of the pool; must not be empty
        :param slots: number of slots of the pool
        :param description: free-text description of the pool
        :param include_deferred: value for the pool's ``include_deferred`` column
        :param session: SQLAlchemy ORM Session
        :return: the created or updated pool object
        :raises ValueError: if ``name`` is empty
        """
        if not name:
            raise ValueError("Pool name must not be empty")

        # Reuse the single by-name lookup instead of duplicating the query here.
        pool = Pool.get_pool(name, session=session)
        if pool is None:
            pool = Pool(pool=name, slots=slots, description=description, include_deferred=include_deferred)
            session.add(pool)
        else:
            pool.slots = slots
            pool.description = description
            pool.include_deferred = include_deferred

        session.commit()
        return pool

    @staticmethod
    @provide_session
    def delete_pool(name: str, session: Session = NEW_SESSION) -> Pool:
        """
        Delete pool by a given name.

        The deletion is committed on ``session`` before returning.

        :param name: name of the pool to delete
        :param session: SQLAlchemy ORM Session
        :return: the pool object that was deleted
        :raises AirflowException: if ``name`` is the default pool name
        :raises PoolNotFound: if no pool with that name exists
        """
        if name == Pool.DEFAULT_POOL_NAME:
            raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")

        pool = Pool.get_pool(name, session=session)
        if pool is None:
            raise PoolNotFound(f"Pool '{name}' doesn't exist")

        session.delete(pool)
        session.commit()

        return pool

    @staticmethod
    @provide_session
    def slots_stats(
        *,
        lock_rows: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, PoolStats]:
        """
        Get Pool stats (Number of Running, Queued, Deferred, Scheduled, Open & Total slots).

        If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
        non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an
        OperationalError.

        :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returns
        :param session: SQLAlchemy ORM Session
        :return: mapping of pool name to its stats; ``total`` (and therefore ``open``) is
            ``float("inf")`` for pools whose slot count is -1
        :raises AirflowException: if the task instance query yields a state outside the expected set
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        pools: dict[str, PoolStats] = {}
        pool_includes_deferred: dict[str, bool] = {}

        query = select(Pool.pool, Pool.slots, Pool.include_deferred)

        if lock_rows:
            query = with_row_locks(query, session=session, nowait=True)

        pool_rows = session.execute(query)
        for pool_name, total_slots, include_deferred in pool_rows:
            if total_slots == -1:
                total_slots = float("inf")  # type: ignore
            pools[pool_name] = PoolStats(
                total=total_slots, running=0, queued=0, open=0, deferred=0, scheduled=0
            )
            pool_includes_deferred[pool_name] = include_deferred

        allowed_execution_states = EXECUTION_STATES | {
            TaskInstanceState.DEFERRED,
            TaskInstanceState.SCHEDULED,
        }
        state_count_by_pool = session.execute(
            select(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.state.in_(allowed_execution_states))
            .group_by(TaskInstance.pool, TaskInstance.state)
        )

        # calculate the per-state metrics
        for pool_name, state, count in state_count_by_pool:
            # Some databases return decimal.Decimal here.
            count = int(count)

            stats_dict: PoolStats | None = pools.get(pool_name)
            # Task instances may reference a pool that has no row; skip those.
            if stats_dict is None:
                continue
            # TypedDict key must be a string literal, so we use if-statements to set value
            if state == TaskInstanceState.RUNNING:
                stats_dict["running"] = count
            elif state == TaskInstanceState.QUEUED:
                stats_dict["queued"] = count
            elif state == TaskInstanceState.DEFERRED:
                stats_dict["deferred"] = count
            elif state == TaskInstanceState.SCHEDULED:
                stats_dict["scheduled"] = count
            else:
                raise AirflowException(f"Unexpected state. Expected values: {allowed_execution_states}.")

        # calculate open metric
        for pool_name, stats_dict in pools.items():
            stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"]
            # Deferred task instances only consume slots for pools that opted in.
            if pool_includes_deferred[pool_name]:
                stats_dict["open"] -= stats_dict["deferred"]

        return pools

    def to_json(self) -> dict[str, Any]:
        """
        Get the Pool in a json structure.

        :return: the pool object in json format
        """
        return {
            "id": self.id,
            "pool": self.pool,
            "slots": self.slots,
            "description": self.description,
            "include_deferred": self.include_deferred,
        }

    def _sum_pool_slots(self, session: Session, *states: TaskInstanceState) -> int:
        """
        Sum ``pool_slots`` over this pool's task instances that are in one of ``states``.

        Shared implementation behind the public ``*_slots`` counters.

        :param session: SQLAlchemy ORM Session
        :param states: task instance states to include in the sum
        :return: the summed slots, 0 when no task instance matches
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        total = session.scalar(
            select(func.sum(TaskInstance.pool_slots)).where(
                TaskInstance.pool == self.pool,
                TaskInstance.state.in_(states),
            )
        )
        # SUM over zero rows is NULL; some databases return decimal.Decimal otherwise.
        return int(total or 0)

    @provide_session
    def occupied_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots occupied at the moment.

        The states that count as occupying a slot are given by ``get_occupied_states``.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_pool_slots(session, *self.get_occupied_states())

    def get_occupied_states(self):
        """
        Get the task instance states that count as occupying a slot of this pool.

        :return: ``EXECUTION_STATES``, extended with ``DEFERRED`` when ``include_deferred`` is set
        """
        if self.include_deferred:
            return EXECUTION_STATES | {
                TaskInstanceState.DEFERRED,
            }
        return EXECUTION_STATES

    @provide_session
    def running_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_pool_slots(session, TaskInstanceState.RUNNING)

    @provide_session
    def queued_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._sum_pool_slots(session, TaskInstanceState.QUEUED)

    @provide_session
    def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots scheduled at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of scheduled slots
        """
        return self._sum_pool_slots(session, TaskInstanceState.SCHEDULED)

    @provide_session
    def deferred_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots deferred at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of deferred slots
        """
        return self._sum_pool_slots(session, TaskInstanceState.DEFERRED)

    @provide_session
    def open_slots(self, session: Session = NEW_SESSION) -> float:
        """
        Get the number of slots open at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of slots, ``float("inf")`` when the pool's slot count is -1
        """
        if self.slots == -1:
            return float("inf")
        return self.slots - self.occupied_slots(session=session)