Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/_shared/dagnode/node.py: 38%


# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar

import structlog

if TYPE_CHECKING:
    import sys
    from collections.abc import Collection, Iterable

    # Replicate `airflow.typing_compat.Self` to avoid illegal imports
    if sys.version_info >= (3, 11):
        from typing import Self
    else:
        from typing_extensions import Self

    from ..logging.types import Logger


class DagProtocol(Protocol):
    """Protocol defining the minimum interface required for the Dag generic type."""

    dag_id: str
    task_dict: dict[str, Any]

    def get_task(self, tid: str) -> Any:
        """Retrieve a task by its task ID."""
        ...


class TaskProtocol(Protocol):
    """Protocol defining the minimum interface required for the Task generic type."""

    task_id: str
    is_setup: bool
    is_teardown: bool
    downstream_list: Iterable[Self]
    downstream_task_ids: set[str]


class TaskGroupProtocol(Protocol):
    """Protocol defining the minimum interface required for the TaskGroup generic type."""

    node_id: str
    prefix_group_id: bool


Dag = TypeVar("Dag", bound=DagProtocol)
Task = TypeVar("Task", bound=TaskProtocol)
TaskGroup = TypeVar("TaskGroup", bound=TaskGroupProtocol)
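
# Note: the bounds above use structural typing. Any class exposing the
# protocol attributes satisfies them; no inheritance from the protocols is
# required. For instance, a hypothetical class with just `dag_id`,
# `task_dict`, and `get_task()` is a valid `Dag` type argument (see the
# usage sketch at the end of this module).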


class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
    """
    Generic class for a node in the graph of a workflow.

    A node may be an operator or task group, either mapped or unmapped.
    """

    dag: Dag | None
    task_group: TaskGroup | None
    downstream_group_ids: set[str | None]
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    _log_config_logger_name: str | None = None
    _logger_name: str | None = None
    _cached_logger: Logger | None = None

    def __init__(self):
        super().__init__()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    @property
    def log(self) -> Logger:
        if self._cached_logger is not None:
            return self._cached_logger

        typ = type(self)

        logger_name: str = (
            self._logger_name if self._logger_name is not None else f"{typ.__module__}.{typ.__qualname__}"
        )

        if self._log_config_logger_name:
            logger_name = (
                f"{self._log_config_logger_name}.{logger_name}"
                if logger_name
                else self._log_config_logger_name
            )

        self._cached_logger = structlog.get_logger(logger_name)
        return self._cached_logger
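    # Sketch of the composed logger name (values illustrative, not from this
    # module): with the defaults, a subclass named "MyNode" in module "my.mod"
    # logs as "my.mod.MyNode"; if _log_config_logger_name were set to
    # "airflow.task", the name would become "airflow.task.my.mod.MyNode".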

    @property
    def dag_id(self) -> str:
        if self.dag:
            return self.dag.dag_id
        return "_in_memory_dag_"

    @property
    def node_id(self) -> str:
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(tg.node_id) + 1 :]
        return self.node_id
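    # Example of the prefix stripping above (hypothetical IDs): a node with
    # node_id "section_1.task_a" in a group whose node_id is "section_1"
    # (prefix_group_id=True) gets label
    # "section_1.task_a"[len("section_1") + 1:] == "task_a".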

    @property
    def upstream_list(self) -> Iterable[Task]:
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(tid) for tid in self.upstream_task_ids]

    @property
    def downstream_list(self) -> Iterable[Task]:
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(tid) for tid in self.downstream_task_ids]

    def has_dag(self) -> bool:
        return self.dag is not None

    def get_dag(self) -> Dag | None:
        return self.dag

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get the set of direct relative IDs of the current task, upstream or downstream."""
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Task]:
        """Get the list of direct relatives of the current task, upstream or downstream."""
        if upstream:
            return self.upstream_list
        return self.downstream_list

    def get_flat_relative_ids(self, *, upstream: bool = False, depth: int | None = None) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        if depth is not None and depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")

        dag = self.get_dag()
        if not dag:
            return set()

        relatives: set[str] = set()

        # This is intentionally implemented as a loop, instead of calling
        # get_direct_relative_ids() recursively, since Python has a significant
        # limit on stack depth, and a recursive implementation can blow up
        # if a DAG contains very long routes.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        levels_remaining = depth
        while task_ids_to_trace:
            # If depth is set, traversal is bounded and should stop once
            # no more levels remain.
            if levels_remaining is not None and levels_remaining <= 0:
                break
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next
            if levels_remaining is not None:
                levels_remaining -= 1

        return relatives
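    # Illustrative trace of the bounded traversal (hypothetical chain of
    # tasks a -> b -> c -> d, called on "a" with upstream=False): each pass
    # of the while loop consumes one level, so depth=1 yields {"b"},
    # depth=2 yields {"b", "c"}, and depth=None yields {"b", "c", "d"}.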

    def get_flat_relatives(self, upstream: bool = False, depth: int | None = None) -> Collection[Task]:
        """
        Get a flat list of relatives, either upstream or downstream.

        :param upstream: Whether to look for upstream or downstream relatives.
        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        dag = self.get_dag()
        if not dag:
            return set()
        return [
            dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream, depth=depth)
        ]

    def get_upstreams_follow_setups(self, depth: int | None = None) -> Iterable[Task]:
        """
        All upstreams and, for each upstream setup, its respective teardowns.

        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        for task in self.get_flat_relatives(upstream=True, depth=depth):
            yield task
            if task.is_setup:
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Task]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant in this case means that the setup has a teardown that is downstream of ``self``,
        or that the setup has no teardowns at all.
        """
        downstream_teardown_ids = {
            x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
        }
        for task in self.get_flat_relatives(upstream=True):
            if not task.is_setup:
                continue
            has_no_teardowns = not any(x.is_teardown for x in task.downstream_list)
            # Yield the setup if it has no teardowns, or if any of its teardowns is downstream of self.
            if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
                yield task
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t
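    # Hypothetical example of "relevant": given setup s1 -> work w ->
    # teardown t1, plus an upstream setup s2 whose only teardown is NOT
    # downstream of w, calling this on w yields s1 and t1 but skips s2:
    # s2 does have teardowns, yet none of them is reachable downstream of w.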

    def get_upstreams_only_setups(self) -> Iterable[Task]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        for task in self.get_upstreams_only_setups_and_teardowns():
            if task.is_setup:
                yield task
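

# --- Illustrative usage sketch (an assumption, not part of the original
# module) --- A minimal, hypothetical pair of classes ("DemoDag"/"DemoTask",
# invented for this example) showing how the protocols above are satisfied
# structurally, and how the flat-relative traversal behaves.
if __name__ == "__main__":

    class DemoDag:
        # Satisfies DagProtocol structurally: dag_id, task_dict, get_task().
        def __init__(self, dag_id: str) -> None:
            self.dag_id = dag_id
            self.task_dict: dict[str, Any] = {}

        def get_task(self, tid: str) -> Any:
            return self.task_dict[tid]

    class DemoTask(GenericDAGNode[DemoDag, "DemoTask", Any]):
        is_setup = False
        is_teardown = False
        task_group = None

        def __init__(self, dag: DemoDag, task_id: str) -> None:
            super().__init__()
            self.dag = dag
            self.task_id = task_id
            dag.task_dict[task_id] = self

        @property
        def node_id(self) -> str:
            return self.task_id

    dag = DemoDag("demo")
    a, b, c = (DemoTask(dag, tid) for tid in "abc")
    # Wire a -> b -> c by hand; real operators do this via >> / set_downstream.
    a.downstream_task_ids.add("b")
    b.upstream_task_ids.add("a")
    b.downstream_task_ids.add("c")
    c.upstream_task_ids.add("b")

    print(a.get_flat_relative_ids())               # {'b', 'c'} (set order may vary)
    print(a.get_flat_relative_ids(depth=1))        # {'b'}
    print(c.get_flat_relative_ids(upstream=True))  # {'a', 'b'}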