Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/edgemodifier.py: 30%

66 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19from typing import Sequence 

20 

21from airflow.models.taskmixin import DAGNode, DependencyMixin 

22from airflow.utils.task_group import TaskGroup 

23 

24 

class EdgeModifier(DependencyMixin):
    """
    Class that represents edge information to be added between two
    tasks/operators. Has shorthand factory functions, like Label("hooray").

    Current implementation supports
        t1 >> Label("Success route") >> t2
        t2 << Label("Success route") << t1

    Note that due to the potential for use in either direction, this waits
    to make the actual connection between both sides until both are declared,
    and will do so progressively if multiple ups/downs are added.

    This and EdgeInfo are related - an EdgeModifier is the Python object you
    use to add information to (potentially multiple) edges, and EdgeInfo
    is the representation of the information for one specific edge.
    """

    def __init__(self, label: str | None = None):
        # Text stored on every edge this modifier annotates (see add_edge_info).
        self.label = label
        # Nodes declared on each side of the modifier. Connections are only
        # made once both sides are non-empty (see set_upstream/set_downstream).
        self._upstream: list[DependencyMixin] = []
        self._downstream: list[DependencyMixin] = []

    @property
    def roots(self) -> Sequence[DependencyMixin]:
        """The nodes declared downstream of this modifier."""
        return self._downstream

    @property
    def leaves(self) -> Sequence[DependencyMixin]:
        """The nodes declared upstream of this modifier."""
        return self._upstream

    @staticmethod
    def _make_list(item_or_list: DependencyMixin | Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
        """Wrap a single node in a one-element list; return sequences unchanged."""
        if not isinstance(item_or_list, Sequence):
            return [item_or_list]
        return item_or_list

    def _save_nodes(
        self,
        nodes: DependencyMixin | Sequence[DependencyMixin],
        stream: list[DependencyMixin],
    ):
        """
        Append each of ``nodes`` to ``stream`` (mutated in place).

        :param nodes: a single node or a sequence of nodes.
        :param stream: either ``self._upstream`` or ``self._downstream``.
        :raises TypeError: if a node is not a TaskGroup, XComArg or DAGNode.
            Nodes preceding the offending one have already been appended.
        """
        # Local import, kept as in the original (presumably to avoid an import cycle).
        from airflow.models.xcom_arg import XComArg

        for node in self._make_list(nodes):
            if isinstance(node, (TaskGroup, XComArg, DAGNode)):
                stream.append(node)
            else:
                raise TypeError(
                    f"Cannot use edge labels with {type(node).__name__}, "
                    f"only tasks, XComArg or TaskGroups"
                )

    @staticmethod
    def _group_key(task_group: TaskGroup) -> str | None:
        """Return "root" for the root TaskGroup, otherwise the group's group_id."""
        return "root" if task_group.is_root else task_group.group_id

    def _convert_streams_to_task_groups(self):
        """
        Both self._upstream and self._downstream are required to determine if
        we should convert a node to a TaskGroup or leave it as a DAGNode.

        To do this, we keep a set of group keys seen among the streams. If
        exactly one key is collected (all nodes come from the same TaskGroup),
        the nodes are left as they are. Otherwise, including when no key was
        collected at all, both streams are converted with
        ``_convert_stream_to_task_groups``.
        """
        from airflow.models.xcom_arg import XComArg

        group_ids: set[str | None] = set()
        for node in [*self._upstream, *self._downstream]:
            # Branch order matters: a node matching the DAGNode check is never
            # re-examined by the later branches.
            if isinstance(node, DAGNode) and node.task_group:
                group_ids.add(self._group_key(node.task_group))
            elif isinstance(node, TaskGroup):
                group_ids.add(node.group_id)
            elif isinstance(node, XComArg):
                # An XComArg is keyed by the TaskGroup of the operator it wraps.
                if isinstance(node.operator, DAGNode) and node.operator.task_group:
                    group_ids.add(self._group_key(node.operator.task_group))

        # If all nodes originate from the same TaskGroup, we will not convert them
        if len(group_ids) != 1:
            self._upstream = self._convert_stream_to_task_groups(self._upstream)
            self._downstream = self._convert_stream_to_task_groups(self._downstream)

    def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
        """
        Replace every DAGNode that sits in a non-root TaskGroup by that TaskGroup.

        All other nodes (root-level DAGNodes, TaskGroups, XComArgs) are kept as-is.
        """
        return [
            node.task_group
            if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root
            else node
            for node in stream
        ]

    def set_upstream(
        self,
        other: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """
        Sets the given task/list onto the upstream attribute, and then checks if
        we have both sides so we can resolve the relationship.

        Providing this also provides << via DependencyMixin.

        The ``edge_modifier`` argument is not used; ``self`` is passed as the
        edge modifier to every downstream node instead.
        """
        self._save_nodes(other, self._upstream)
        if self._upstream and self._downstream:
            # Convert _upstream and _downstream to task_groups only after both are set
            self._convert_streams_to_task_groups()
        # No-op while nothing has been declared downstream yet.
        for node in self._downstream:
            node.set_upstream(other, edge_modifier=self)

    def set_downstream(
        self,
        other: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """
        Sets the given task/list onto the downstream attribute, and then checks if
        we have both sides so we can resolve the relationship.

        Providing this also provides >> via DependencyMixin.

        The ``edge_modifier`` argument is not used; ``self`` is passed as the
        edge modifier to every upstream node instead.
        """
        self._save_nodes(other, self._downstream)
        if self._upstream and self._downstream:
            # Convert _upstream and _downstream to task_groups only after both are set
            self._convert_streams_to_task_groups()
        # No-op while nothing has been declared upstream yet.
        for node in self._upstream:
            node.set_downstream(other, edge_modifier=self)

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Called if we're not the "main" side of a relationship; we still run the
        same logic, though.

        :param other: the node on the other side of the relationship.
        :param upstream: if True, ``other`` is recorded as upstream, else downstream.
        :param edge_modifier: not used by this implementation.
        """
        if upstream:
            self.set_upstream(other)
        else:
            self.set_downstream(other)

    def add_edge_info(self, dag, upstream_id: str, downstream_id: str):
        """
        Adds or updates task info on the DAG for this specific pair of tasks.

        Called either from our relationship trigger methods above, or directly
        by set_upstream/set_downstream in operators. The stored info is a dict
        with a single ``"label"`` key holding ``self.label``.
        """
        dag.set_edge_info(upstream_id, downstream_id, {"label": self.label})

174 

175 

# Factory functions
def Label(label: str):
    """
    Build an EdgeModifier that carries a human-readable *label* for the edge.

    Typical use: ``t1 >> Label("Success route") >> t2``.
    """
    modifier = EdgeModifier(label=label)
    return modifier