Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/_shared/dagnode/node.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

111 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20from typing import TYPE_CHECKING, Generic, TypeVar 

21 

22import structlog 

23 

24if TYPE_CHECKING: 

25 from collections.abc import Collection, Iterable 

26 

27 from ..logging.types import Logger 

28 

# Type parameters for GenericDAGNode. The concrete DAG, task (operator) and
# task-group classes are supplied by whichever layer (SDK, scheduler) binds
# this generic node implementation.
Dag = TypeVar("Dag")
Task = TypeVar("Task")
TaskGroup = TypeVar("TaskGroup")

32 

33 

class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
    """
    Generic class for a node in the graph of a workflow.

    A node may be an operator or task group, either mapped or unmapped.
    """

    # Structural attributes; populated by subclasses / DAG assignment.
    dag: Dag | None
    task_group: TaskGroup | None
    downstream_group_ids: set[str | None]
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    # Hooks controlling the logger name built by ``log`` (see below).
    _log_config_logger_name: str | None = None
    _logger_name: str | None = None
    _cached_logger: Logger | None = None

    def __init__(self):
        super().__init__()
        # Edges are tracked as task ids; Task objects are resolved lazily
        # through the owning dag (see upstream_list / downstream_list).
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    @property
    def log(self) -> Logger:
        """Return a structlog logger for this node, memoised per instance."""
        if self._cached_logger is None:
            cls = type(self)
            # Default to the fully qualified class name unless overridden.
            name = self._logger_name
            if name is None:
                name = f"{cls.__module__}.{cls.__qualname__}"
            # Optionally nest under a configured parent logger name.
            if self._log_config_logger_name:
                name = f"{self._log_config_logger_name}.{name}" if name else self._log_config_logger_name
            self._cached_logger = structlog.get_logger(name)
        return self._cached_logger

    @property
    def dag_id(self) -> str:
        """Return the owning dag's id, or a placeholder when unassigned."""
        if not self.dag:
            return "_in_memory_dag_"
        return self.dag.dag_id

    @property
    def node_id(self) -> str:
        # Concrete node types must provide their own identifier.
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """Node id with the enclosing task group's prefix stripped, if any."""
        group = self.task_group
        if group and group.node_id and group.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(group.node_id) + 1 :]
        return self.node_id

    @property
    def upstream_list(self) -> Iterable[Task]:
        """Resolve upstream task ids into Task objects via the owning dag."""
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        get_task = self.dag.get_task
        return [get_task(task_id) for task_id in self.upstream_task_ids]

    @property
    def downstream_list(self) -> Iterable[Task]:
        """Resolve downstream task ids into Task objects via the owning dag."""
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        get_task = self.dag.get_task
        return [get_task(task_id) for task_id in self.downstream_task_ids]

    def has_dag(self) -> bool:
        """Whether this node has been assigned to a dag."""
        return self.dag is not None

    def get_dag(self) -> Dag | None:
        """Return the owning dag, or None if unassigned."""
        return self.dag

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get set of the direct relative ids to the current task, upstream or downstream."""
        return self.upstream_task_ids if upstream else self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Task]:
        """Get list of the direct relatives to the current task, upstream or downstream."""
        return self.upstream_list if upstream else self.downstream_list

    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        """
        dag = self.get_dag()
        if not dag:
            return set()

        seen: set[str] = set()

        # Breadth-first traversal by hand rather than recursion: Python's
        # stack depth is limited, and a very long chain of tasks could
        # otherwise overflow it.
        frontier = self.get_direct_relative_ids(upstream)
        while frontier:
            next_frontier: set[str] = set()
            for task_id in frontier:
                if task_id in seen:
                    continue
                seen.add(task_id)
                next_frontier.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
            frontier = next_frontier

        return seen

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Task]:
        """Get a flat list of relatives, either upstream or downstream."""
        dag = self.get_dag()
        if not dag:
            return set()
        task_map = dag.task_dict
        return [task_map[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]

    def get_upstreams_follow_setups(self) -> Iterable[Task]:
        """All upstreams and, for each upstream setup, its respective teardowns."""
        for relative in self.get_flat_relatives(upstream=True):
            yield relative
            if not relative.is_setup:
                continue
            # For a setup task, also surface its teardown counterparts.
            for candidate in relative.downstream_list:
                if candidate.is_teardown and candidate != self:
                    yield candidate

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Task]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant in this case means, the setup has a teardown that is downstream of ``self``,
        or the setup has no teardowns.
        """
        my_downstream_teardown_ids = {
            relative.task_id for relative in self.get_flat_relatives(upstream=False) if relative.is_teardown
        }
        for setup in self.get_flat_relatives(upstream=True):
            if not setup.is_setup:
                continue
            has_teardowns = any(down.is_teardown for down in setup.downstream_list)
            # Skip setups whose teardowns all lie outside self's downstream set.
            if has_teardowns and not setup.downstream_task_ids.intersection(my_downstream_teardown_ids):
                continue
            yield setup
            for down in setup.downstream_list:
                if down.is_teardown and down != self:
                    yield down

    def get_upstreams_only_setups(self) -> Iterable[Task]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        return (task for task in self.get_upstreams_only_setups_and_teardowns() if task.is_setup)