Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/pool.py: 49%
115 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20from typing import Any, Iterable
22from sqlalchemy import Column, Integer, String, Text, func
23from sqlalchemy.orm.session import Session
25from airflow.exceptions import AirflowException, PoolNotFound
26from airflow.models.base import Base
27from airflow.ti_deps.dependencies_states import EXECUTION_STATES
28from airflow.typing_compat import TypedDict
29from airflow.utils.session import NEW_SESSION, provide_session
30from airflow.utils.sqlalchemy import nowait, with_row_locks
31from airflow.utils.state import State
class PoolStats(TypedDict):
    """Slot counters reported for a single pool."""

    # Number of slots configured on the pool.
    total: int
    # Slots currently held by task instances in the "running" state.
    running: int
    # Slots currently held by task instances in the "queued" state.
    queued: int
    # Slots still available: total minus running minus queued.
    open: int
class Pool(Base):
    """
    ORM model for a row of the ``slot_pool`` table, with helpers to query pools and slot usage.

    Each pool has a unique name and a slot count. A ``slots`` value of ``-1`` is treated as
    "unlimited" by :meth:`slots_stats` and :meth:`open_slots`.
    """

    __tablename__ = "slot_pool"

    id = Column(Integer, primary_key=True)
    pool = Column(String(256), unique=True)
    # -1 for infinite
    slots = Column(Integer, default=0)
    description = Column(Text)

    DEFAULT_POOL_NAME = "default_pool"

    def __repr__(self) -> str:
        return str(self.pool)

    @staticmethod
    @provide_session
    def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
        """
        Get all pools.

        :param session: SQLAlchemy ORM Session
        :return: every row of the pool table as a list of Pool objects
        """
        return session.query(Pool).all()

    @staticmethod
    @provide_session
    def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool with specific pool name from the Pools.

        :param pool_name: The pool name of the Pool to get.
        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None when no pool has that name
        """
        return session.query(Pool).filter(Pool.pool == pool_name).first()

    @staticmethod
    @provide_session
    def get_default_pool(session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool of the default_pool from the Pools.

        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None when the default pool row does not exist
        """
        return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)

    @staticmethod
    @provide_session
    def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
        """
        Check id if is the default_pool.

        :param id: pool id
        :param session: SQLAlchemy ORM Session
        :return: True if id is default_pool, otherwise False
        """
        # Count rows matching both the id and the default name; any match means "yes".
        return (
            session.query(func.count(Pool.id))
            .filter(Pool.id == id, Pool.pool == Pool.DEFAULT_POOL_NAME)
            .scalar()
            > 0
        )

    @staticmethod
    @provide_session
    def create_or_update_pool(
        name: str,
        slots: int,
        description: str,
        session: Session = NEW_SESSION,
    ) -> Pool:
        """
        Create a pool with given parameters or update it if it already exists.

        :param name: name of the pool; must not be empty
        :param slots: number of slots to store on the pool
        :param description: description to store on the pool
        :param session: SQLAlchemy ORM Session
        :return: the created or updated pool object
        :raises ValueError: if ``name`` is empty
        """
        if not name:
            raise ValueError("Pool name must not be empty")

        pool = session.query(Pool).filter_by(pool=name).one_or_none()
        if pool is None:
            pool = Pool(pool=name, slots=slots, description=description)
            session.add(pool)
        else:
            pool.slots = slots
            pool.description = description

        session.commit()
        return pool

    @staticmethod
    @provide_session
    def delete_pool(name: str, session: Session = NEW_SESSION) -> Pool:
        """
        Delete pool by a given name.

        :param name: name of the pool to delete
        :param session: SQLAlchemy ORM Session
        :return: the pool object that was deleted
        :raises AirflowException: if ``name`` is the default pool name
        :raises PoolNotFound: if no pool with that name exists
        """
        if name == Pool.DEFAULT_POOL_NAME:
            raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")

        pool = session.query(Pool).filter_by(pool=name).first()
        if pool is None:
            raise PoolNotFound(f"Pool '{name}' doesn't exist")

        session.delete(pool)
        session.commit()

        return pool

    @staticmethod
    @provide_session
    def slots_stats(
        *,
        lock_rows: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, PoolStats]:
        """
        Get Pool stats (Number of Running, Queued, Open & Total tasks).

        If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
        non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an
        OperationalError.

        Pools whose slot count is ``-1`` report ``total`` (and therefore ``open``) as ``float("inf")``.
        Task instances that reference a pool name with no matching pool row are ignored.

        :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returns
        :param session: SQLAlchemy ORM Session
        :return: mapping of pool name to its stats
        :raises AirflowException: if a counted task instance state is neither "running" nor "queued"
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        pools: dict[str, PoolStats] = {}

        query = session.query(Pool.pool, Pool.slots)

        if lock_rows:
            query = with_row_locks(query, session=session, **nowait(session))

        pool_rows: Iterable[tuple[str, int]] = query.all()
        for pool_name, total_slots in pool_rows:
            if total_slots == -1:
                total_slots = float("inf")  # type: ignore
            pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0)

        state_count_by_pool = (
            session.query(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
            .group_by(TaskInstance.pool, TaskInstance.state)
        ).all()

        # calculate queued and running metrics
        for pool_name, state, count in state_count_by_pool:
            # Some databases return decimal.Decimal here.
            count = int(count)

            stats_dict: PoolStats | None = pools.get(pool_name)
            if stats_dict is None:
                # Task instance points at a pool that has no row; nothing to attribute it to.
                continue
            # TypedDict key must be a string literal, so we use if-statements to set value
            if state == "running":
                stats_dict["running"] = count
            elif state == "queued":
                stats_dict["queued"] = count
            else:
                raise AirflowException(f"Unexpected state. Expected values: {EXECUTION_STATES}.")

        # calculate open metric
        for pool_stats in pools.values():
            pool_stats["open"] = pool_stats["total"] - pool_stats["running"] - pool_stats["queued"]

        return pools

    def to_json(self) -> dict[str, Any]:
        """
        Get the Pool in a json structure.

        :return: the pool object in json format
        """
        return {
            "id": self.id,
            "pool": self.pool,
            "slots": self.slots,
            "description": self.description,
        }

    @provide_session
    def occupied_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running/queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        # SUM() yields NULL/None when no rows match, hence the ``or 0`` fallback.
        return int(
            session.query(func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state.in_(EXECUTION_STATES))
            .scalar()
            or 0
        )

    @provide_session
    def running_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        return int(
            session.query(func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state == State.RUNNING)
            .scalar()
            or 0
        )

    @provide_session
    def queued_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        return int(
            session.query(func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state == State.QUEUED)
            .scalar()
            or 0
        )

    @provide_session
    def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots scheduled at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of scheduled slots
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        return int(
            session.query(func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state == State.SCHEDULED)
            .scalar()
            or 0
        )

    @provide_session
    def open_slots(self, session: Session = NEW_SESSION) -> float:
        """
        Get the number of slots open at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of slots; ``float("inf")`` when the pool's slot count is ``-1``
        """
        if self.slots == -1:
            return float("inf")
        # Pass the session by keyword so the nested call reuses it explicitly.
        return self.slots - self.occupied_slots(session=session)