Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/task/priority_strategy.py: 49%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

57 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Priority weight strategies for task scheduling.""" 

19 

20from __future__ import annotations 

21 

22from abc import ABC, abstractmethod 

23from typing import TYPE_CHECKING, Any 

24 

25from airflow.exceptions import AirflowException 

26 

27if TYPE_CHECKING: 

28 from airflow.models.taskinstance import TaskInstance 

29 

30 

class PriorityWeightStrategy(ABC):
    """Priority weight strategy interface.

    This feature is experimental and subject to change at any time.

    Currently, we don't serialize the priority weight strategy parameters. This means that
    the priority weight strategy must be stateless, but you can add class attributes, and
    create multiple subclasses with different attributes values if you need to create
    different versions of the same strategy.
    """

    @abstractmethod
    def get_weight(self, ti: TaskInstance):
        """Get the priority weight of a task.

        :param ti: The task instance whose priority weight is computed.
        """
        ...

    @classmethod
    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
        """Deserialize a priority weight strategy from data.

        This is called when a serialized DAG is deserialized. ``data`` will be whatever
        was returned by ``serialize`` during DAG serialization. The default
        implementation passes ``data`` to the constructor as keyword arguments; with the
        default ``serialize`` (which returns an empty dict) this constructs the strategy
        without any arguments.

        :param data: The dict previously produced by ``serialize``.
        :return: A new instance of ``cls``.
        """
        return cls(**data)  # type: ignore[call-arg]

    def serialize(self) -> dict[str, Any]:
        """Serialize the priority weight strategy for JSON encoding.

        This is called during DAG serialization to store priority weight strategy information
        in the database. This should return a JSON-serializable dict that will be fed into
        ``deserialize`` when the DAG is deserialized. The default implementation returns
        an empty dict.
        """
        return {}

    def __eq__(self, other: object) -> bool:
        """Equality comparison.

        Two strategies are equal only when they are instances of exactly the same class
        and their ``serialize()`` outputs are equal. Requiring the exact type (rather than
        ``isinstance``) keeps the comparison symmetric when one strategy class subclasses
        another. Returning ``NotImplemented`` lets Python try the other operand and
        ultimately fall back to identity, which yields ``False`` for unrelated objects.
        """
        if type(other) is not type(self):
            return NotImplemented
        return self.serialize() == other.serialize()

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__``.

        Defining ``__eq__`` alone would set ``__hash__`` to ``None`` and make strategies
        unhashable. Equal strategies always share the same concrete type, so hashing on
        the type alone satisfies the hash/equality contract without requiring the
        ``serialize()`` output to be hashable.
        """
        return hash(type(self))

72 

73 

class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that uses the task's priority weight directly."""

    def get_weight(self, ti: TaskInstance):
        """Return the task's own ``priority_weight``; relatives are not considered."""
        task = ti.task
        if TYPE_CHECKING:
            assert task
        return task.priority_weight

81 

82 

class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that uses the sum of the priority weights of all downstream tasks."""

    def get_weight(self, ti: TaskInstance) -> int:
        """Return the task's own weight plus the weights of all its downstream tasks.

        When the task is not attached to a DAG, only its own weight is returned.
        """
        task = ti.task
        if TYPE_CHECKING:
            assert task
        own_weight = task.priority_weight
        dag = task.get_dag()
        if dag is None:
            return own_weight
        downstream_ids = task.get_flat_relative_ids(upstream=False)
        relatives_weight = sum(dag.task_dict[relative_id].priority_weight for relative_id in downstream_ids)
        return own_weight + relatives_weight

96 

97 

class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
    """Priority weight strategy that uses the sum of the priority weights of all upstream tasks."""

    def get_weight(self, ti: TaskInstance):
        """Return the task's own weight plus the weights of all its upstream tasks.

        When the task is not attached to a DAG, only its own weight is returned.
        """
        task = ti.task
        if TYPE_CHECKING:
            assert task
        own_weight = task.priority_weight
        dag = task.get_dag()
        if dag is None:
            return own_weight
        upstream_ids = task.get_flat_relative_ids(upstream=True)
        relatives_weight = sum(dag.task_dict[relative_id].priority_weight for relative_id in upstream_ids)
        return own_weight + relatives_weight

110 

111 

# Built-in strategies, keyed by the short names that
# ``validate_and_load_priority_weight_strategy`` accepts as plain strings.
airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = dict(
    absolute=_AbsolutePriorityWeightStrategy,
    downstream=_DownstreamPriorityWeightStrategy,
    upstream=_UpstreamPriorityWeightStrategy,
)


# Reverse lookup of the mapping above: built-in strategy class -> its short name.
airflow_priority_weight_strategies_classes: dict[type[PriorityWeightStrategy], str] = dict(
    zip(airflow_priority_weight_strategies.values(), airflow_priority_weight_strategies.keys())
)

122 

123 

def validate_and_load_priority_weight_strategy(
    priority_weight_strategy: str | PriorityWeightStrategy | None,
) -> PriorityWeightStrategy:
    """Validate and load a priority weight strategy.

    Returns the priority weight strategy if it is valid, otherwise raises an exception.

    Resolution order:

    * ``None`` -> a new ``_AbsolutePriorityWeightStrategy``.
    * a built-in short name (a key of ``airflow_priority_weight_strategies``) -> a new
      instance of that built-in class.
    * an instance of a built-in strategy class -> a new instance of that same class.
    * any other string is treated as a class path, and any other object is resolved via
      the qualified name of its class. Either is looked up with
      ``_get_registered_priority_weight_strategy``, and the class found is instantiated
      with no arguments.

    :param priority_weight_strategy: The priority weight strategy to validate and load.
    :return: A freshly constructed strategy instance.
    :raises AirflowException: if the strategy cannot be resolved to a registered class.

    :meta private:
    """
    from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy
    from airflow.utils.module_loading import qualname

    if priority_weight_strategy is None:
        return _AbsolutePriorityWeightStrategy()

    if isinstance(priority_weight_strategy, str):
        if priority_weight_strategy in airflow_priority_weight_strategies:
            return airflow_priority_weight_strategies[priority_weight_strategy]()
        priority_weight_strategy_class = priority_weight_strategy
    else:
        strategy_type = type(priority_weight_strategy)
        # Built-in strategies are keyed by short name ("absolute", ...), not by qualified
        # class path. An already-loaded built-in instance is therefore resolved directly
        # here instead of via its qualname in the lookup below.
        # NOTE(review): confirm whether _get_registered_priority_weight_strategy also
        # recognises built-in class paths; if it does, this is only a fast path.
        if strategy_type in airflow_priority_weight_strategies_classes:
            return strategy_type()
        priority_weight_strategy_class = qualname(priority_weight_strategy)
    loaded_priority_weight_strategy = _get_registered_priority_weight_strategy(priority_weight_strategy_class)
    if loaded_priority_weight_strategy is None:
        raise AirflowException(f"Unknown priority strategy {priority_weight_strategy_class}")
    return loaded_priority_weight_strategy()