Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/edgemodifier.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

66 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19from typing import Sequence 

20 

21from airflow.models.taskmixin import DAGNode, DependencyMixin 

22from airflow.utils.task_group import TaskGroup 

23 

24 

class EdgeModifier(DependencyMixin):
    """
    Class that represents edge information to be added between two tasks/operators.

    Has shorthand factory functions, like Label("hooray").

    Current implementation supports
        t1 >> Label("Success route") >> t2
        t2 << Label("Success route") << t1

    Note that due to the potential for use in either direction, this waits
    to make the actual connection between both sides until both are declared,
    and will do so progressively if multiple ups/downs are added.

    This and EdgeInfo are related - an EdgeModifier is the Python object you
    use to add information to (potentially multiple) edges, and EdgeInfo
    is the representation of the information for one specific edge.
    """

    def __init__(self, label: str | None = None):
        """Store the optional edge label and start with no nodes recorded on either side."""
        self.label = label
        # Nodes collected so far on each side of this modifier.
        self._upstream: list[DependencyMixin] = []
        self._downstream: list[DependencyMixin] = []

    @property
    def roots(self):
        """Return the nodes currently recorded on the downstream side."""
        return self._downstream

    @property
    def leaves(self):
        """Return the nodes currently recorded on the upstream side."""
        return self._upstream

    @staticmethod
    def _make_list(item_or_list: DependencyMixin | Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
        """Wrap a single item in a list; return an existing ``Sequence`` unchanged."""
        if not isinstance(item_or_list, Sequence):
            return [item_or_list]
        return item_or_list

    @staticmethod
    def _group_key(task_group):
        """Return ``"root"`` for a task group flagged ``is_root``, otherwise its ``group_id``."""
        return "root" if task_group.is_root else task_group.group_id

    def _save_nodes(
        self,
        nodes: DependencyMixin | Sequence[DependencyMixin],
        stream: list[DependencyMixin],
    ):
        """
        Append each of ``nodes`` to ``stream`` after checking its type.

        :param nodes: a single node or a sequence of nodes.
        :param stream: the list to extend in place (callers in this class pass
            ``self._upstream`` or ``self._downstream``).
        :raises TypeError: if a node is not a TaskGroup, XComArg or DAGNode.
        """
        # Imported at call time rather than module level - presumably to avoid
        # a circular import; confirm before hoisting.
        from airflow.models.xcom_arg import XComArg

        for node in self._make_list(nodes):
            if isinstance(node, (TaskGroup, XComArg, DAGNode)):
                stream.append(node)
            else:
                raise TypeError(
                    f"Cannot use edge labels with {type(node).__name__}, "
                    "only tasks, XComArg or TaskGroups"
                )

    def _convert_streams_to_task_groups(self):
        """
        Convert a node to a TaskGroup or leave it as a DAGNode.

        Requires both self._upstream and self._downstream.

        To do this, we keep a set of group_ids seen among the streams. If we find that
        the nodes are from the same TaskGroup, we will leave them as DAGNodes and not
        convert them to TaskGroups
        """
        from airflow.models.xcom_arg import XComArg

        group_ids = set()
        for node in [*self._upstream, *self._downstream]:
            # Branch order matters: a node with a truthy ``task_group`` is keyed
            # by that group before the TaskGroup / XComArg checks are reached.
            if isinstance(node, DAGNode) and node.task_group:
                group_ids.add(self._group_key(node.task_group))
            elif isinstance(node, TaskGroup):
                group_ids.add(node.group_id)
            elif isinstance(node, XComArg):
                if isinstance(node.operator, DAGNode) and node.operator.task_group:
                    group_ids.add(self._group_key(node.operator.task_group))

        # If all nodes originate from the same TaskGroup, we will not convert them
        if len(group_ids) != 1:
            self._upstream = self._convert_stream_to_task_groups(self._upstream)
            self._downstream = self._convert_stream_to_task_groups(self._downstream)

    def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
        """
        Return a new list in which each DAGNode with a non-root task group is replaced by that group.

        All other nodes are kept as they are, in their original order.
        """
        return [
            node.task_group
            if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root
            else node
            for node in stream
        ]

    def set_upstream(
        self,
        other: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """
        Set the given task/list onto the upstream attribute, then attempt to resolve the relationship.

        Providing this also provides << via DependencyMixin.

        :param other: node or sequence of nodes to record as upstream.
        :param edge_modifier: accepted for signature compatibility; not read here,
            this instance is always the modifier passed on to the nodes.
        """
        self._save_nodes(other, self._upstream)
        if self._upstream and self._downstream:
            # Convert _upstream and _downstream to task_groups only after both are set
            self._convert_streams_to_task_groups()
        # Note: ``other`` is forwarded exactly as given by the caller, not the
        # (possibly converted) contents of ``self._upstream``.
        for node in self._downstream:
            node.set_upstream(other, edge_modifier=self)

    def set_downstream(
        self,
        other: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """
        Set the given task/list onto the downstream attribute, then attempt to resolve the relationship.

        Providing this also provides >> via DependencyMixin.

        :param other: node or sequence of nodes to record as downstream.
        :param edge_modifier: accepted for signature compatibility; not read here,
            this instance is always the modifier passed on to the nodes.
        """
        self._save_nodes(other, self._downstream)
        if self._upstream and self._downstream:
            # Convert _upstream and _downstream to task_groups only after both are set
            self._convert_streams_to_task_groups()
        # Note: ``other`` is forwarded exactly as given by the caller, not the
        # (possibly converted) contents of ``self._downstream``.
        for node in self._upstream:
            node.set_downstream(other, edge_modifier=self)

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Update relative if we're not the "main" side of a relationship; still run the same logic.

        :param other: the node to relate to this modifier.
        :param upstream: when true ``other`` is set upstream, otherwise downstream.
        :param edge_modifier: accepted but not used by this method.
        """
        if upstream:
            self.set_upstream(other)
        else:
            self.set_downstream(other)

    def add_edge_info(self, dag, upstream_id: str, downstream_id: str):
        """
        Add or update task info on the DAG for this specific pair of tasks.

        Called either from our relationship trigger methods above, or directly
        by set_upstream/set_downstream in operators.

        :param dag: object whose ``set_edge_info`` receives the edge data.
        :param upstream_id: identifier of the upstream end of the edge.
        :param downstream_id: identifier of the downstream end of the edge.
        """
        dag.set_edge_info(upstream_id, downstream_id, {"label": self.label})

171 

172 

173# Factory functions 

def Label(label: str):
    """Return a new EdgeModifier carrying ``label`` as the human-readable text for the edge."""
    modifier = EdgeModifier(label=label)
    return modifier