1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Priority weight strategies for task scheduling."""
19
20from __future__ import annotations
21
22from abc import ABC, abstractmethod
23from typing import TYPE_CHECKING, Any
24
25from airflow.task.weight_rule import WeightRule
26
27if TYPE_CHECKING:
28 from airflow.models.taskinstance import TaskInstance
29
30
def _make_hashable(value: Any) -> Any:
    """
    Recursively convert JSON-like ``value`` into a hashable equivalent.

    ``dict`` becomes a ``frozenset`` of ``(key, value)`` pairs, ``list``/``tuple`` become
    ``tuple`` and ``set``/``frozenset`` become ``frozenset``; any other value is returned
    unchanged (and must itself be hashable for the result to be hashed).
    """
    if isinstance(value, dict):
        return frozenset((key, _make_hashable(item)) for key, item in value.items())
    if isinstance(value, (list, tuple)):
        return tuple(_make_hashable(item) for item in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(_make_hashable(item) for item in value)
    return value


class PriorityWeightStrategy(ABC):
    """
    Priority weight strategy interface.

    This feature is experimental and subject to change at any time.

    Currently, we don't serialize the priority weight strategy parameters. This means that
    the priority weight strategy must be stateless, but you can add class attributes, and
    create multiple subclasses with different attribute values if you need to create
    different versions of the same strategy.
    """

    @abstractmethod
    def get_weight(self, ti: TaskInstance) -> int:
        """
        Get the priority weight of a task.

        :param ti: The task instance whose priority weight is computed.
        :return: The priority weight for ``ti``.
        """
        ...

    @classmethod
    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
        """
        Deserialize a priority weight strategy from data.

        This is called when a serialized DAG is deserialized. ``data`` will be whatever
        was returned by ``serialize`` during DAG serialization. The default
        implementation passes ``data`` to the constructor as keyword arguments; with the
        default ``serialize`` (which returns an empty dict) this constructs the strategy
        without any arguments.
        """
        return cls(**data)

    def serialize(self) -> dict[str, Any]:
        """
        Serialize the priority weight strategy for JSON encoding.

        This is called during DAG serialization to store priority weight strategy information
        in the database. This should return a JSON-serializable dict that will be fed into
        ``deserialize`` when the DAG is deserialized. The default implementation returns
        an empty dict.
        """
        return {}

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` is an instance of this type and serializes to the same data."""
        if not isinstance(other, type(self)):
            return False
        return self.serialize() == other.serialize()

    def __hash__(self) -> int:
        # ``serialize()`` returns a dict, which cannot be hashed directly, so hash a frozen
        # copy. Only the serialized data (not the type) is hashed, so any two instances
        # that ``__eq__`` considers equal are guaranteed to hash equally.
        return hash(_make_hashable(self.serialize()))
78
79
class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that uses the task's priority weight directly."""

    def get_weight(self, ti: TaskInstance) -> int:
        """
        Return the task's own ``priority_weight``, ignoring any upstream or downstream tasks.

        :param ti: The task instance to weigh; ``ti.task`` is assumed to be set (the
            ``assert`` below only narrows the type for static checkers).
        """
        if TYPE_CHECKING:
            assert ti.task
        return ti.task.priority_weight
87
88
class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
    """Weigh a task as its own priority weight plus the weights of all its downstream tasks."""

    def get_weight(self, ti: TaskInstance) -> int:
        task = ti.task
        if TYPE_CHECKING:
            assert task
        dag = task.get_dag()
        own_weight = task.priority_weight
        # Without a DAG the relatives cannot be resolved, so only the task's own weight counts.
        if dag is None:
            return own_weight
        downstream_ids = task.get_flat_relative_ids(upstream=False)
        return own_weight + sum(dag.task_dict[tid].priority_weight for tid in downstream_ids)
102
103
class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that uses the sum of the priority weights of all upstream tasks."""

    def get_weight(self, ti: TaskInstance) -> int:
        """
        Return the task's own ``priority_weight`` plus that of every upstream task in its DAG.

        If the task is not attached to a DAG, only its own ``priority_weight`` is returned.

        :param ti: The task instance to weigh; ``ti.task`` is assumed to be set (the
            ``assert`` below only narrows the type for static checkers).
        """
        if TYPE_CHECKING:
            assert ti.task
        dag = ti.task.get_dag()
        if dag is None:
            return ti.task.priority_weight
        # Relatives are returned as task ids, so resolve each through the DAG's task_dict.
        return ti.task.priority_weight + sum(
            dag.task_dict[task_id].priority_weight for task_id in ti.task.get_flat_relative_ids(upstream=True)
        )
116
117
# Built-in strategies, keyed by the ``WeightRule`` value that selects each one.
airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
    WeightRule.ABSOLUTE: _AbsolutePriorityWeightStrategy,
    WeightRule.DOWNSTREAM: _DownstreamPriorityWeightStrategy,
    WeightRule.UPSTREAM: _UpstreamPriorityWeightStrategy,
}


# Reverse lookup of the mapping above: built-in strategy class -> ``WeightRule`` value.
airflow_priority_weight_strategies_classes = dict(
    zip(airflow_priority_weight_strategies.values(), airflow_priority_weight_strategies.keys())
)
128
129
def validate_and_load_priority_weight_strategy(
    priority_weight_strategy: str | PriorityWeightStrategy | None,
) -> PriorityWeightStrategy:
    """
    Resolve ``priority_weight_strategy`` into a strategy instance, validating it on the way.

    Resolution order:

    * ``None`` yields a new ``_AbsolutePriorityWeightStrategy``.
    * A string that is a key of ``airflow_priority_weight_strategies`` yields a new
      instance of that built-in strategy.
    * Any other string is looked up as a name via ``_get_registered_priority_weight_strategy``.
    * Any non-string value is first converted to a name with ``qualname`` and then
      looked up the same way.

    The class found in the registry is instantiated with no arguments, so when a strategy
    object is passed in, a freshly constructed object is returned rather than the one given.

    :param priority_weight_strategy: The priority weight strategy to validate and load.
    :raises ValueError: If the name is neither built-in nor registered.

    :meta private:
    """
    from airflow._shared.module_loading import qualname
    from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy

    if priority_weight_strategy is None:
        return _AbsolutePriorityWeightStrategy()

    if not isinstance(priority_weight_strategy, str):
        priority_weight_strategy_class = qualname(priority_weight_strategy)
    elif priority_weight_strategy in airflow_priority_weight_strategies:
        builtin_cls = airflow_priority_weight_strategies[priority_weight_strategy]
        return builtin_cls()
    else:
        priority_weight_strategy_class = priority_weight_strategy

    registered_cls = _get_registered_priority_weight_strategy(priority_weight_strategy_class)
    if registered_cls is None:
        raise ValueError(f"Unknown priority strategy {priority_weight_strategy_class}")
    return registered_cls()