Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/task/priority_strategy.py: 49%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Priority weight strategies for task scheduling."""
20from __future__ import annotations
22from abc import ABC, abstractmethod
23from typing import TYPE_CHECKING, Any
25from airflow.exceptions import AirflowException
27if TYPE_CHECKING:
28 from airflow.models.taskinstance import TaskInstance
class PriorityWeightStrategy(ABC):
    """Priority weight strategy interface.

    This feature is experimental and subject to change at any time.

    Currently, we don't serialize the priority weight strategy parameters. This means that
    the priority weight strategy must be stateless, but you can add class attributes, and
    create multiple subclasses with different attributes values if you need to create
    different versions of the same strategy.

    Two strategies compare equal when the right-hand operand is an instance of the
    left-hand operand's class and both serialize to the same data. Instances are
    hashable, and equal instances share the same hash.
    """

    @abstractmethod
    def get_weight(self, ti: TaskInstance) -> int:
        """Get the priority weight of a task.

        :param ti: The task instance whose priority weight is computed.
        :return: The priority weight to use for ``ti``.
        """
        ...

    @classmethod
    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
        """Deserialize a priority weight strategy from data.

        This is called when a serialized DAG is deserialized. ``data`` will be whatever
        was returned by ``serialize`` during DAG serialization. The default
        implementation passes the items of ``data`` to the constructor as keyword
        arguments; with the default ``serialize`` (an empty dict) this constructs the
        priority weight strategy without any arguments.

        :param data: The dict previously produced by ``serialize``.
        :return: A strategy instance built from ``data``.
        """
        return cls(**data)  # type: ignore[call-arg]

    def serialize(self) -> dict[str, Any]:
        """Serialize the priority weight strategy for JSON encoding.

        This is called during DAG serialization to store priority weight strategy information
        in the database. This should return a JSON-serializable dict that will be fed into
        ``deserialize`` when the DAG is deserialized. The default implementation returns
        an empty dict.

        :return: A JSON-serializable dict describing this strategy.
        """
        return {}

    def __eq__(self, other: object) -> bool:
        """Equality comparison.

        :param other: The object to compare against.
        :return: ``True`` if ``other`` is an instance of this object's class and both
            serialize to equal dicts, ``False`` otherwise.
        """
        if not isinstance(other, type(self)):
            return False
        return self.serialize() == other.serialize()

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__``.

        Defining ``__eq__`` alone would make instances unhashable, so a matching
        ``__hash__`` is provided.
        """
        # Only the serialized keys are hashed: the values may be unhashable (lists,
        # nested dicts), while equal dicts always have equal key sets. The class is
        # deliberately left out because ``__eq__`` can report a subclass instance as
        # equal to a parent-class instance.
        return hash(frozenset(self.serialize()))
class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that uses the task's priority weight directly."""

    def get_weight(self, ti: TaskInstance) -> int:
        """Return the ``priority_weight`` of the task attached to ``ti``.

        :param ti: The task instance whose task supplies the weight.
        :return: ``ti.task.priority_weight``, unmodified.
        """
        if TYPE_CHECKING:
            # Narrowing hint for static type checkers only; this never runs.
            assert ti.task
        return ti.task.priority_weight
class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that uses the sum of the priority weights of all downstream tasks."""

    def get_weight(self, ti: TaskInstance) -> int:
        """Return the task's own weight plus the weights of all its downstream tasks.

        :param ti: The task instance whose task is weighted.
        :return: The task's ``priority_weight`` when the task has no DAG; otherwise
            that weight plus the ``priority_weight`` of every task id returned by
            ``get_flat_relative_ids(upstream=False)``.
        """
        if TYPE_CHECKING:
            # Narrowing hint for static type checkers only; this never runs.
            assert ti.task
        task = ti.task
        dag = task.get_dag()
        own_weight = task.priority_weight
        if dag is None:
            # Without a DAG there is no task_dict to resolve relatives from.
            return own_weight
        downstream_weight = 0
        for task_id in task.get_flat_relative_ids(upstream=False):
            downstream_weight += dag.task_dict[task_id].priority_weight
        return own_weight + downstream_weight
class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that uses the sum of the priority weights of all upstream tasks."""

    def get_weight(self, ti: TaskInstance) -> int:
        """Return the task's own weight plus the weights of all its upstream tasks.

        :param ti: The task instance whose task is weighted.
        :return: The task's ``priority_weight`` when the task has no DAG; otherwise
            that weight plus the ``priority_weight`` of every task id returned by
            ``get_flat_relative_ids(upstream=True)``.
        """
        if TYPE_CHECKING:
            # Narrowing hint for static type checkers only; this never runs.
            assert ti.task
        dag = ti.task.get_dag()
        if dag is None:
            # Without a DAG there is no task_dict to resolve relatives from.
            return ti.task.priority_weight
        return ti.task.priority_weight + sum(
            dag.task_dict[task_id].priority_weight
            for task_id in ti.task.get_flat_relative_ids(upstream=True)
        )
# Built-in strategies, keyed by the short string names that
# ``validate_and_load_priority_weight_strategy`` resolves directly.
airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
    "absolute": _AbsolutePriorityWeightStrategy,
    "downstream": _DownstreamPriorityWeightStrategy,
    "upstream": _UpstreamPriorityWeightStrategy,
}


# Reverse lookup of the mapping above: built-in strategy class -> its short name.
airflow_priority_weight_strategies_classes: dict[type[PriorityWeightStrategy], str] = {
    cls: name for name, cls in airflow_priority_weight_strategies.items()
}
def validate_and_load_priority_weight_strategy(
    priority_weight_strategy: str | PriorityWeightStrategy | None,
) -> PriorityWeightStrategy:
    """Validate and load a priority weight strategy.

    Returns the priority weight strategy if it is valid, otherwise raises an exception.

    Resolution order:

    * ``None`` yields the absolute strategy.
    * A string naming a built-in strategy yields a new instance of that built-in.
    * Any other string is looked up as-is among the registered strategies.
    * A non-string value is looked up by ``qualname(value)``.

    In every case a freshly constructed instance is returned; a strategy object
    passed in is used only to determine which registered class to instantiate.

    :param priority_weight_strategy: The priority weight strategy to validate and load.
    :return: A new instance of the resolved strategy class.
    :raises AirflowException: If no registered strategy matches the lookup key.

    :meta private:
    """
    from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy
    from airflow.utils.module_loading import qualname

    if priority_weight_strategy is None:
        return _AbsolutePriorityWeightStrategy()

    if not isinstance(priority_weight_strategy, str):
        # NOTE(review): presumably the dotted import path of the object's class -- confirm
        # against ``airflow.utils.module_loading.qualname``.
        lookup_key = qualname(priority_weight_strategy)
    else:
        builtin_cls = airflow_priority_weight_strategies.get(priority_weight_strategy)
        if builtin_cls is not None:
            return builtin_cls()
        lookup_key = priority_weight_strategy

    strategy_cls = _get_registered_priority_weight_strategy(lookup_key)
    if strategy_cls is None:
        raise AirflowException(f"Unknown priority strategy {lookup_key}")
    return strategy_cls()