Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/pool.py: 49%
115 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20from typing import Iterable
22from sqlalchemy import Column, Integer, String, Text, func
23from sqlalchemy.orm.session import Session
25from airflow.exceptions import AirflowException, PoolNotFound
26from airflow.models.base import Base
27from airflow.ti_deps.dependencies_states import EXECUTION_STATES
28from airflow.typing_compat import TypedDict
29from airflow.utils.session import NEW_SESSION, provide_session
30from airflow.utils.sqlalchemy import nowait, with_row_locks
31from airflow.utils.state import State
class PoolStats(TypedDict):
    """
    Slot usage summary for a single pool, as produced by ``Pool.slots_stats``.

    Although annotated as ``int``, ``total`` and ``open`` hold ``float("inf")``
    for a pool whose configured slot count is ``-1`` (unlimited).
    """

    total: int  # configured slot capacity of the pool
    running: int  # slots consumed by task instances in the running state
    queued: int  # slots consumed by task instances in the queued state
    open: int  # total - running - queued
class Pool(Base):
    """ORM model for the ``slot_pool`` table, with helpers to query pools and their slot usage."""

    __tablename__ = "slot_pool"

    id = Column(Integer, primary_key=True)
    pool = Column(String(256), unique=True)
    # -1 for infinite
    slots = Column(Integer, default=0)
    description = Column(Text)

    DEFAULT_POOL_NAME = "default_pool"

    def __repr__(self):
        return str(self.pool)

    @staticmethod
    @provide_session
    def get_pools(session: Session = NEW_SESSION):
        """
        Get all pools.

        :param session: SQLAlchemy ORM Session
        :return: a list of every Pool row
        """
        return session.query(Pool).all()

    @staticmethod
    @provide_session
    def get_pool(pool_name: str, session: Session = NEW_SESSION):
        """
        Get the Pool with a specific pool name.

        :param pool_name: The pool name of the Pool to get.
        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None if no pool has that name
        """
        return session.query(Pool).filter(Pool.pool == pool_name).first()

    @staticmethod
    @provide_session
    def get_default_pool(session: Session = NEW_SESSION):
        """
        Get the Pool named ``DEFAULT_POOL_NAME`` ("default_pool").

        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None if it does not exist
        """
        return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)

    @staticmethod
    @provide_session
    def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
        """
        Check whether the pool with the given id is the default pool.

        :param id: pool id
        :param session: SQLAlchemy ORM Session
        :return: True if the row with this id is the default_pool, otherwise False
        """
        return (
            session.query(func.count(Pool.id))
            .filter(Pool.id == id, Pool.pool == Pool.DEFAULT_POOL_NAME)
            .scalar()
            > 0
        )

    @staticmethod
    @provide_session
    def create_or_update_pool(name: str, slots: int, description: str, session: Session = NEW_SESSION):
        """
        Create a pool with the given parameters, or update it if it already exists.

        The session is committed after the insert/update.

        :param name: pool name; when empty, nothing is done and None is returned
        :param slots: number of slots (-1 for infinite)
        :param description: pool description
        :param session: SQLAlchemy ORM Session
        :return: the created or updated Pool, or None when ``name`` is empty
        """
        if not name:
            return
        pool = session.query(Pool).filter_by(pool=name).first()
        if pool is None:
            pool = Pool(pool=name, slots=slots, description=description)
            session.add(pool)
        else:
            pool.slots = slots
            pool.description = description

        session.commit()

        return pool

    @staticmethod
    @provide_session
    def delete_pool(name: str, session: Session = NEW_SESSION):
        """
        Delete a pool by name and commit the session.

        :param name: name of the pool to delete
        :param session: SQLAlchemy ORM Session
        :return: the deleted Pool object
        :raises AirflowException: if ``name`` is the default pool
        :raises PoolNotFound: if no pool has that name
        """
        if name == Pool.DEFAULT_POOL_NAME:
            raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")

        pool = session.query(Pool).filter_by(pool=name).first()
        if pool is None:
            raise PoolNotFound(f"Pool '{name}' doesn't exist")

        session.delete(pool)
        session.commit()

        return pool

    @staticmethod
    @provide_session
    def slots_stats(
        *,
        lock_rows: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, PoolStats]:
        """
        Get pool slot stats (running, queued, open and total slots) keyed by pool name.

        If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
        non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an
        OperationalError.

        Pools configured with -1 slots report ``float("inf")`` for ``total`` and ``open``.

        :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returned
        :param session: SQLAlchemy ORM Session
        :return: mapping of pool name to its PoolStats
        :raises AirflowException: if the per-state query yields a state other than running/queued
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        pools: dict[str, PoolStats] = {}

        query = session.query(Pool.pool, Pool.slots)

        if lock_rows:
            query = with_row_locks(query, session=session, **nowait(session))

        pool_rows: Iterable[tuple[str, int]] = query.all()
        for pool_name, total_slots in pool_rows:
            if total_slots == -1:
                total_slots = float("inf")  # type: ignore
            pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0)

        state_count_by_pool = (
            session.query(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
            .group_by(TaskInstance.pool, TaskInstance.state)
        ).all()

        # calculate queued and running metrics
        for pool_name, state, count in state_count_by_pool:
            # Some databases return decimal.Decimal here.
            count = int(count)

            stats_dict: PoolStats | None = pools.get(pool_name)
            if stats_dict is None:
                # Task instances may reference a pool that has no slot_pool row; skip them.
                continue
            # TypedDict key must be a string literal, so we use if-statements to set value
            if state == State.RUNNING:
                stats_dict["running"] = count
            elif state == State.QUEUED:
                stats_dict["queued"] = count
            else:
                raise AirflowException(f"Unexpected state. Expected values: {EXECUTION_STATES}.")

        # calculate open metric
        for stats in pools.values():
            stats["open"] = stats["total"] - stats["running"] - stats["queued"]

        return pools

    def to_json(self):
        """
        Get the Pool as a JSON-serializable dict.

        :return: dict with the pool's ``id``, ``pool``, ``slots`` and ``description``
        """
        return {
            "id": self.id,
            "pool": self.pool,
            "slots": self.slots,
            "description": self.description,
        }

    def _slots_in_states(self, states: Iterable[str], session: Session) -> int:
        """
        Sum ``pool_slots`` of the task instances in this pool whose state is in ``states``.

        :param states: task instance states to include
        :param session: SQLAlchemy ORM Session
        :return: the summed number of slots, 0 when no task instance matches
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        used = (
            session.query(func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state.in_(list(states)))
            .scalar()
        )
        # SUM() over zero rows gives NULL (None); some databases return decimal.Decimal.
        return int(used or 0)

    @provide_session
    def occupied_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running/queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._slots_in_states(EXECUTION_STATES, session)

    @provide_session
    def running_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._slots_in_states([State.RUNNING], session)

    @provide_session
    def queued_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._slots_in_states([State.QUEUED], session)

    @provide_session
    def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by scheduled tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of scheduled slots
        """
        return self._slots_in_states([State.SCHEDULED], session)

    @provide_session
    def open_slots(self, session: Session = NEW_SESSION) -> float:
        """
        Get the number of slots open at the moment.

        :param session: SQLAlchemy ORM Session
        :return: ``float("inf")`` for a pool with -1 slots, otherwise slots minus occupied slots
        """
        if self.slots == -1:
            return float("inf")
        return self.slots - self.occupied_slots(session=session)