1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
"""This module defines the task-instance dependency that checks pool slot availability."""
19
20from __future__ import annotations
21
22from sqlalchemy import select
23
24from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
25from airflow.utils.session import provide_session
26
27
class PoolSlotsAvailableDep(BaseTIDep):
    """Dependency that is met only when the task's pool has room for the task."""

    NAME = "Pool Slots Available"
    IGNORABLE = True

    @provide_session
    def _get_dep_statuses(self, ti, session, dep_context=None):
        """
        Check whether the pool assigned to ``ti`` can accommodate it.

        Exactly one status is yielded:

        * failing, if no pool row named ``ti.pool`` exists;
        * failing, if the pool's open slots are not enough for ``ti.pool_slots``
          (slots ``ti`` itself holds are counted as open when its state is one of
          the pool's occupied states);
        * passing otherwise.

        :param ti: the task instance to get the dependency status for
        :param session: database session
        :param dep_context: the context for which this dependency should be evaluated for
            (not consulted by this dependency)
        :return: a generator yielding a single dependency status
        """
        # Imported inside the method because a module-level import would be circular.
        from airflow.models.pool import Pool

        name = ti.pool

        # The slot_pool table has a UNIQUE key on the pool name, so at most one row matches.
        found: Pool | None = session.scalar(select(Pool).where(Pool.pool == name))
        if found is None:
            yield self._failing_status(
                reason=f"Tasks using non-existent pool '{name}' will not be scheduled"
            )
            return

        free = found.open_slots(session=session)
        if ti.state in found.get_occupied_states():
            # NOTE(review): presumably the pool already counts this task's own slots
            # as occupied in this state; adding them back keeps the task from
            # blocking itself — confirm against Pool.open_slots.
            free += ti.pool_slots

        needed = ti.pool_slots
        if free <= needed - 1:
            yield self._failing_status(
                reason=f"Not scheduling since there are {free} open slots in pool {name} "
                f"and require {needed} pool slots"
            )
            return

        yield self._passing_status(
            reason=f"There are enough open slots in {name} to execute the task",
        )