Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/task/priority_strategy.py: 38%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

60 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Priority weight strategies for task scheduling.""" 

19 

20from __future__ import annotations 

21 

22from abc import ABC, abstractmethod 

23from typing import TYPE_CHECKING, Any 

24 

25from airflow.task.weight_rule import WeightRule 

26 

27if TYPE_CHECKING: 

28 from airflow.models.taskinstance import TaskInstance 

29 

30 

class PriorityWeightStrategy(ABC):
    """
    Priority weight strategy interface.

    This feature is experimental and subject to change at any time.

    Currently, we don't serialize the priority weight strategy parameters. This means that
    the priority weight strategy must be stateless, but you can add class attributes, and
    create multiple subclasses with different attributes values if you need to create
    different versions of the same strategy.
    """

    @abstractmethod
    def get_weight(self, ti: TaskInstance) -> int:
        """
        Get the priority weight of a task.

        :param ti: The task instance whose priority weight is computed.
        :return: The priority weight for ``ti``.
        """
        ...

    @classmethod
    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
        """
        Deserialize a priority weight strategy from data.

        This is called when a serialized DAG is deserialized. ``data`` will be whatever
        was returned by ``serialize`` during DAG serialization. The default
        implementation passes ``data`` to the constructor as keyword arguments; with
        the default ``serialize`` (an empty dict) the strategy is therefore constructed
        without any arguments.

        :param data: The dict previously produced by ``serialize``.
        :return: A new instance of ``cls``.
        """
        return cls(**data)

    def serialize(self) -> dict[str, Any]:
        """
        Serialize the priority weight strategy for JSON encoding.

        This is called during DAG serialization to store priority weight strategy information
        in the database. This should return a JSON-serializable dict that will be fed into
        ``deserialize`` when the DAG is deserialized. The default implementation returns
        an empty dict.

        :return: A JSON-serializable dict describing this strategy.
        """
        return {}

    def __eq__(self, other: object) -> bool:
        """
        Equality comparison.

        Two strategies are equal when ``other`` is an instance of this strategy's type
        and both serialize to equal dicts.
        """
        if not isinstance(other, type(self)):
            return False
        return self.serialize() == other.serialize()

    def __hash__(self) -> int:
        """
        Hash the strategy by its serialized form, consistently with ``__eq__``.

        ``serialize`` returns a dict, which is not hashable, so the serialized data is
        first converted into an equivalent immutable structure. Strategies that compare
        equal always produce the same hash.

        :raises TypeError: If the serialized data contains a value that is neither a
            dict, a list, a tuple, nor hashable.
        """

        def _freeze(value: Any) -> Any:
            # Dict equality ignores insertion order, so use an unordered frozenset of pairs.
            if isinstance(value, dict):
                return frozenset((key, _freeze(item)) for key, item in value.items())
            if isinstance(value, (list, tuple)):
                return tuple(_freeze(item) for item in value)
            return value

        # The concrete type is deliberately left out of the hash: ``__eq__`` accepts
        # subclass instances, and equal objects must hash equally.
        return hash(_freeze(self.serialize()))

78 

79 

class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that returns the task's own priority weight unchanged."""

    def get_weight(self, ti: TaskInstance):
        # Return the ``priority_weight`` of the task attached to ``ti``.
        task = ti.task
        if TYPE_CHECKING:
            # Narrowing for static type checkers only; never executed at runtime.
            assert task
        return task.priority_weight

87 

88 

class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that adds the weights of all downstream tasks to the task's own."""

    def get_weight(self, ti: TaskInstance) -> int:
        # Own weight plus the weights of every task returned by
        # ``get_flat_relative_ids(upstream=False)``; own weight only when there is no DAG.
        task = ti.task
        if TYPE_CHECKING:
            # Narrowing for static type checkers only; never executed at runtime.
            assert task
        dag = task.get_dag()
        own_weight = task.priority_weight
        if dag is None:
            return own_weight
        downstream_ids = task.get_flat_relative_ids(upstream=False)
        return own_weight + sum(dag.task_dict[relative_id].priority_weight for relative_id in downstream_ids)

102 

103 

class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that adds the weights of all upstream tasks to the task's own."""

    def get_weight(self, ti: TaskInstance):
        # Own weight plus the weights of every task returned by
        # ``get_flat_relative_ids(upstream=True)``; own weight only when there is no DAG.
        task = ti.task
        if TYPE_CHECKING:
            # Narrowing for static type checkers only; never executed at runtime.
            assert task
        dag = task.get_dag()
        own_weight = task.priority_weight
        if dag is None:
            return own_weight
        upstream_ids = task.get_flat_relative_ids(upstream=True)
        return own_weight + sum(dag.task_dict[relative_id].priority_weight for relative_id in upstream_ids)

116 

117 

# Built-in strategies, keyed by the weight rule that selects them.
airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
    WeightRule.ABSOLUTE: _AbsolutePriorityWeightStrategy,
    WeightRule.DOWNSTREAM: _DownstreamPriorityWeightStrategy,
    WeightRule.UPSTREAM: _UpstreamPriorityWeightStrategy,
}


# Reverse lookup: built-in strategy class -> the weight rule it is registered under.
airflow_priority_weight_strategies_classes = {
    strategy_cls: rule for rule, strategy_cls in airflow_priority_weight_strategies.items()
}

128 

129 

def validate_and_load_priority_weight_strategy(
    priority_weight_strategy: str | PriorityWeightStrategy | None,
) -> PriorityWeightStrategy:
    """
    Validate and load a priority weight strategy.

    ``None`` yields the absolute strategy. A string naming a built-in strategy yields a new
    instance of that built-in. Any other string is used as-is as the lookup name, and a
    non-string value is converted to a name with ``qualname``; that name is then looked up
    among the registered strategies and a new instance of the match is returned.

    :param priority_weight_strategy: The priority weight strategy to validate and load.
    :return: An instance of the resolved priority weight strategy.
    :raises ValueError: If the name is neither a built-in nor a registered strategy.

    :meta private:
    """
    from airflow._shared.module_loading import qualname
    from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy

    if priority_weight_strategy is None:
        return _AbsolutePriorityWeightStrategy()

    if not isinstance(priority_weight_strategy, str):
        strategy_path = qualname(priority_weight_strategy)
    elif priority_weight_strategy in airflow_priority_weight_strategies:
        # Built-in strategies are resolved directly, without consulting the registry.
        return airflow_priority_weight_strategies[priority_weight_strategy]()
    else:
        strategy_path = priority_weight_strategy

    strategy_cls = _get_registered_priority_weight_strategy(strategy_path)
    if strategy_cls is None:
        raise ValueError(f"Unknown priority strategy {strategy_path}")
    return strategy_cls()