Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/edges.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

76 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19from collections.abc import Sequence 

20from typing import TYPE_CHECKING, TypedDict 

21 

22from airflow.sdk.definitions._internal.mixins import DependencyMixin 

23 

24if TYPE_CHECKING: 

25 from airflow.sdk.definitions.dag import DAG 

26 

27 

class EdgeModifier(DependencyMixin):
    """
    Class that represents edge information to be added between two tasks/operators.

    Has shorthand factory functions, like Label("hooray").

    Current implementation supports
        t1 >> Label("Success route") >> t2
        t2 << Label("Success route") << t1

    Note that due to the potential for use in either direction, this waits
    to make the actual connection between both sides until both are declared,
    and will do so progressively if multiple ups/downs are added.

    This and EdgeInfo are related - an EdgeModifier is the Python object you
    use to add information to (potentially multiple) edges, and EdgeInfo
    is the representation of the information for one specific edge.
    """

    def __init__(self, label: str | None = None) -> None:
        """
        Create a modifier with no nodes attached on either side yet.

        :param label: Text stored on every edge this modifier ends up describing.
        """
        self.label = label
        self._upstream: list[DependencyMixin] = []
        self._downstream: list[DependencyMixin] = []

    @property
    def roots(self):
        """Return the nodes recorded on the downstream side of this modifier."""
        return self._downstream

    @property
    def leaves(self):
        """Return the nodes recorded on the upstream side of this modifier."""
        return self._upstream

    @staticmethod
    def _make_list(
        item_or_list: DependencyMixin | Sequence[DependencyMixin],
    ) -> Sequence[DependencyMixin]:
        """Wrap a single item in a one-element list; hand any ``Sequence`` back unchanged."""
        if not isinstance(item_or_list, Sequence):
            return [item_or_list]
        return item_or_list

    @staticmethod
    def _task_group_key(task_group):
        """Return the key used to compare task groups: ``"root"`` for a root group, else its ``group_id``."""
        return "root" if task_group.is_root else task_group.group_id

    def _save_nodes(
        self,
        nodes: DependencyMixin | Sequence[DependencyMixin],
        stream: list[DependencyMixin],
    ) -> None:
        """
        Validate ``nodes`` and append each of them to ``stream``.

        :param nodes: A single node or a sequence of nodes.
        :param stream: The list to extend in place (``self._upstream`` or ``self._downstream``).
        :raises TypeError: If a node is not a TaskGroup, XComArg or DAGNode. Nodes that
            preceded the offending one have already been appended when this is raised.
        """
        # NOTE(review): imported at call time, presumably to avoid a circular import - confirm.
        from airflow.sdk.definitions._internal.node import DAGNode
        from airflow.sdk.definitions.taskgroup import TaskGroup
        from airflow.sdk.definitions.xcom_arg import XComArg

        for node in self._make_list(nodes):
            if isinstance(node, (TaskGroup, XComArg, DAGNode)):
                stream.append(node)
            else:
                raise TypeError(
                    f"Cannot use edge labels with {type(node).__name__}, only tasks, XComArg or TaskGroups"
                )

    def _convert_streams_to_task_groups(self) -> None:
        """
        Convert a node to a TaskGroup or leave it as a DAGNode.

        Requires both self._upstream and self._downstream.

        To do this, we keep a set of group_ids seen among the streams. If we find that
        the nodes are from the same TaskGroup, we will leave them as DAGNodes and not
        convert them to TaskGroups
        """
        from airflow.sdk.definitions._internal.node import DAGNode
        from airflow.sdk.definitions.taskgroup import TaskGroup
        from airflow.sdk.definitions.xcom_arg import XComArg

        group_ids = set()
        for node in (*self._upstream, *self._downstream):
            if isinstance(node, DAGNode) and node.task_group:
                group_ids.add(self._task_group_key(node.task_group))
            elif isinstance(node, TaskGroup):
                # Only reached when the DAGNode branch above did not apply to this group.
                group_ids.add(node.group_id)
            elif isinstance(node, XComArg):
                # An XComArg contributes the group of the operator it refers to, if any.
                operator = node.operator
                if isinstance(operator, DAGNode) and operator.task_group:
                    group_ids.add(self._task_group_key(operator.task_group))

        # If all nodes originate from the same TaskGroup, we will not convert them
        if len(group_ids) != 1:
            self._upstream = self._convert_stream_to_task_groups(self._upstream)
            self._downstream = self._convert_stream_to_task_groups(self._downstream)

    def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
        """
        Return a new list where each DAGNode inside a non-root task group is replaced by that group.

        Every other entry (nodes in the root group, nodes without a group, non-DAGNode
        entries) is kept as-is, and the order of ``stream`` is preserved.
        """
        from airflow.sdk.definitions._internal.node import DAGNode

        return [
            node.task_group
            if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root
            else node
            for node in stream
        ]

    def set_upstream(
        self,
        other: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """
        Set the given task/list onto the upstream attribute, then attempt to resolve the relationship.

        Providing this also provides << via DependencyMixin.

        :param other: Node or nodes to record as upstream of this modifier.
        :param edge_modifier: Accepted for signature compatibility; not read here, since
            this object is itself passed on as the edge modifier.
        """
        self._save_nodes(other, self._upstream)
        if self._upstream and self._downstream:
            # Convert _upstream and _downstream to task_groups only after both are set
            self._convert_streams_to_task_groups()
        # With no downstream recorded yet this loop is a no-op; the link is made later.
        for node in self._downstream:
            node.set_upstream(other, edge_modifier=self)

    def set_downstream(
        self,
        other: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """
        Set the given task/list onto the downstream attribute, then attempt to resolve the relationship.

        Providing this also provides >> via DependencyMixin.

        :param other: Node or nodes to record as downstream of this modifier.
        :param edge_modifier: Accepted for signature compatibility; not read here, since
            this object is itself passed on as the edge modifier.
        """
        self._save_nodes(other, self._downstream)
        if self._upstream and self._downstream:
            # Convert _upstream and _downstream to task_groups only after both are set
            self._convert_streams_to_task_groups()
        # With no upstream recorded yet this loop is a no-op; the link is made later.
        for node in self._upstream:
            node.set_downstream(other, edge_modifier=self)

    def update_relative(
        self,
        other: DependencyMixin,
        upstream: bool = True,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Update relative if we're not the "main" side of a relationship; still run the same logic.

        :param other: Node on the other side of the relationship.
        :param upstream: When true, ``other`` is recorded upstream; otherwise downstream.
        :param edge_modifier: Accepted for signature compatibility; not forwarded.
        """
        if upstream:
            self.set_upstream(other)
        else:
            self.set_downstream(other)

    def add_edge_info(self, dag: DAG, upstream_id: str, downstream_id: str):
        """
        Add or update task info on the Dag for this specific pair of tasks.

        Called either from our relationship trigger methods above, or directly
        by set_upstream/set_downstream in operators.

        :param dag: Dag whose ``set_edge_info`` receives the label.
        :param upstream_id: Identifier of the upstream task of the edge.
        :param downstream_id: Identifier of the downstream task of the edge.
        """
        dag.set_edge_info(upstream_id, downstream_id, {"label": self.label})

184 

185 

186# Factory functions 

def Label(label: str):
    """
    Build an EdgeModifier whose edge text is ``label``.

    Shorthand for ``EdgeModifier(label=label)``.
    """
    modifier = EdgeModifier(label=label)
    return modifier

190 

191 

192class EdgeInfoType(TypedDict): 

193 """Extra metadata that the Dag can store about an edge, usually generated from an EdgeModifier.""" 

194 

195 label: str | None