Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/_shared/dagnode/node.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

118 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20from typing import TYPE_CHECKING, Generic, TypeVar 

21 

22import structlog 

23 

24if TYPE_CHECKING: 

25 from collections.abc import Collection, Iterable 

26 

27 from ..logging.types import Logger 

28 

# Type parameters for the concrete Dag / Task / TaskGroup implementations
# that a given node class is bound to.
Dag = TypeVar("Dag")
Task = TypeVar("Task")
TaskGroup = TypeVar("TaskGroup")


class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
    """
    Generic class for a node in the graph of a workflow.

    A node may be an operator or task group, either mapped or unmapped.
    """

    # Owning Dag, or None while the node is not yet assigned to one.
    dag: Dag | None
    # Enclosing task group, or None at the Dag root.
    task_group: TaskGroup | None
    downstream_group_ids: set[str | None]
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    # Optional logger-name prefix and explicit logger-name override;
    # the resolved logger is cached after the first access.
    _log_config_logger_name: str | None = None
    _logger_name: str | None = None
    _cached_logger: Logger | None = None

    def __init__(self):
        super().__init__()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    @property
    def log(self) -> Logger:
        """Return the structlog logger for this node, creating and caching it on first use."""
        if self._cached_logger is None:
            cls = type(self)
            name = self._logger_name
            if name is None:
                # Default to the fully qualified class name.
                name = f"{cls.__module__}.{cls.__qualname__}"
            prefix = self._log_config_logger_name
            if prefix:
                name = f"{prefix}.{name}" if name else prefix
            self._cached_logger = structlog.get_logger(name)
        return self._cached_logger

    @property
    def dag_id(self) -> str:
        """Return the id of the owning Dag, or a placeholder when unassigned."""
        return self.dag.dag_id if self.dag else "_in_memory_dag_"

    @property
    def node_id(self) -> str:
        """Unique id of this node within the Dag; concrete subclasses must implement."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """Node id with the enclosing task group's id prefix stripped, if one applies."""
        group = self.task_group
        if group and group.node_id and group.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(group.node_id) + 1 :]
        return self.node_id

    @property
    def upstream_list(self) -> Iterable[Task]:
        """Tasks directly upstream of this node; requires a Dag assignment."""
        dag = self.dag
        if not dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [dag.get_task(task_id) for task_id in self.upstream_task_ids]

    @property
    def downstream_list(self) -> Iterable[Task]:
        """Tasks directly downstream of this node; requires a Dag assignment."""
        dag = self.dag
        if not dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [dag.get_task(task_id) for task_id in self.downstream_task_ids]

    def has_dag(self) -> bool:
        """Return True when this node has been assigned to a Dag."""
        return self.dag is not None

    def get_dag(self) -> Dag | None:
        """Return the Dag this node belongs to, if any."""
        return self.dag

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get set of the direct relative ids to the current task, upstream or downstream."""
        return self.upstream_task_ids if upstream else self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Task]:
        """Get list of the direct relatives to the current task, upstream or downstream."""
        return self.upstream_list if upstream else self.downstream_list

    def get_flat_relative_ids(self, *, upstream: bool = False, depth: int | None = None) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        if depth is not None and depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")

        dag = self.get_dag()
        if not dag:
            return set()

        # Iterative breadth-first walk on purpose: a recursive implementation
        # could exhaust Python's limited stack on a Dag with very long routes.
        found: set[str] = set()
        frontier = self.get_direct_relative_ids(upstream)
        remaining = depth
        # With a bounded depth, stop once the allotted levels are consumed.
        while frontier and (remaining is None or remaining > 0):
            next_frontier: set[str] = set()
            for task_id in frontier - found:
                next_frontier |= dag.task_dict[task_id].get_direct_relative_ids(upstream)
                found.add(task_id)
            frontier = next_frontier
            if remaining is not None:
                remaining -= 1

        return found

    def get_flat_relatives(self, upstream: bool = False, depth: int | None = None) -> Collection[Task]:
        """
        Get a flat list of relatives, either upstream or downstream.

        :param upstream: Whether to look for upstream or downstream relatives.
        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        dag = self.get_dag()
        if not dag:
            return set()
        relative_ids = self.get_flat_relative_ids(upstream=upstream, depth=depth)
        return [dag.task_dict[task_id] for task_id in relative_ids]

    def get_upstreams_follow_setups(self, depth: int | None = None) -> Iterable[Task]:
        """
        All upstreams and, for each upstream setup, its respective teardowns.

        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        for relative in self.get_flat_relatives(upstream=True, depth=depth):
            yield relative
            if not relative.is_setup:
                continue
            for candidate in relative.downstream_list:
                if candidate.is_teardown and candidate != self:
                    yield candidate

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Task]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant in this case means, the setup has a teardown that is downstream of ``self``,
        or the setup has no teardowns.
        """
        my_downstream_teardown_ids = {
            relative.task_id for relative in self.get_flat_relatives(upstream=False) if relative.is_teardown
        }
        for relative in self.get_flat_relatives(upstream=True):
            if not relative.is_setup:
                continue
            teardown_free = all(not t.is_teardown for t in relative.downstream_list)
            # Relevant when the setup has no teardowns at all, or at least one
            # of its teardowns sits downstream of self.
            if teardown_free or relative.downstream_task_ids & my_downstream_teardown_ids:
                yield relative
                for candidate in relative.downstream_list:
                    if candidate.is_teardown and candidate != self:
                        yield candidate

    def get_upstreams_only_setups(self) -> Iterable[Task]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        for task in self.get_upstreams_only_setups_and_teardowns():
            if task.is_setup:
                yield task