Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/node.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

183 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import re 

21from abc import ABCMeta, abstractmethod 

22from collections.abc import Collection, Iterable, Sequence 

23from datetime import datetime 

24from typing import TYPE_CHECKING, Any 

25 

26import structlog 

27 

28from airflow.sdk.definitions._internal.mixins import DependencyMixin 

29 

30if TYPE_CHECKING: 

31 from airflow.sdk.definitions.dag import DAG 

32 from airflow.sdk.definitions.edges import EdgeModifier 

33 from airflow.sdk.definitions.taskgroup import TaskGroup 

34 from airflow.sdk.types import Logger, Operator 

35 from airflow.serialization.enums import DagAttributeTypes 

36 

37 

# Characters permitted in a task/DAG key: word chars, dots, and dashes.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Group keys are stricter: no dots, which join group and task ids (see DAGNode.label).
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Matches interior runs of uppercase letters (not at the start of the string),
# presumably for CamelCase -> snake_case conversion elsewhere — used outside this view.
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")

41 

42 

def validate_key(k: str, max_length: int = 250):
    """
    Validate a value used as a task/DAG key.

    :param k: candidate key; must be a string of at most ``max_length``
        characters matching ``KEY_REGEX``.
    :param max_length: inclusive upper bound on the key length.
    :raises TypeError: if ``k`` is not a string.
    :raises ValueError: if ``k`` is too long or contains disallowed characters.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    length = len(k)
    if length > max_length:
        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
    if KEY_REGEX.match(k) is None:
        msg = (
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots, and underscores exclusively"
        )
        raise ValueError(msg)

54 

55 

def validate_group_key(k: str, max_length: int = 200):
    """
    Validate a value used as a task-group key.

    Same contract as :func:`validate_key`, except dots are not allowed
    (they separate group and task ids in fully qualified node ids) and
    the default length limit is 200.

    :raises TypeError: if ``k`` is not a string.
    :raises ValueError: if ``k`` is too long or contains disallowed characters.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    length = len(k)
    if length > max_length:
        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
    if GROUP_KEY_REGEX.match(k) is None:
        raise ValueError(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, and underscores exclusively"
        )

66 

67 

class DAGNode(DependencyMixin, metaclass=ABCMeta):
    """
    A base class for a node in the graph of a workflow.

    A node may be an Operator or a Task Group, either mapped or unmapped.
    """

    # The Dag this node has been assigned to, if any.
    dag: DAG | None
    task_group: TaskGroup | None
    """The task_group that contains this node"""
    # Scheduling window bounds; may be unset on nodes that don't define them.
    start_date: datetime | None
    end_date: datetime | None
    # Direct (one-hop) dependency edges, stored as task ids.
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    # Optional prefix prepended to the logger name (see ``log``).
    _log_config_logger_name: str | None = None
    # Explicit logger-name override; defaults to the class's module-qualified name.
    _logger_name: str | None = None
    # Lazily created structlog logger, built on first access to ``log``.
    _cached_logger: Logger | None = None

    def __init__(self) -> None:
        # Edge sets are per-instance state; initialize fresh so instances
        # never share the class-level annotation as a container.
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()
        super().__init__()

    def get_dag(self) -> DAG | None:
        """Return the Dag this node is attached to, or None if unassigned."""
        return self.dag

    @property
    @abstractmethod
    def node_id(self) -> str:
        """Unique id of this node within its Dag (may include a group prefix)."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """Display label: the node id with the enclosing group prefix stripped."""
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(tg.node_id) + 1 :]
        return self.node_id

    def has_dag(self) -> bool:
        """Return True if this node has been assigned to a Dag."""
        return self.dag is not None

    @property
    def dag_id(self) -> str:
        """Returns dag id if it has one or an adhoc/meaningless ID."""
        if self.dag:
            return self.dag.dag_id
        return "_in_memory_dag_"

    @property
    def log(self) -> Logger:
        """
        Get a logger for this node.

        The logger name is determined by:
        1. Using _logger_name if provided
        2. Otherwise, using the class's module and qualified name
        3. Prefixing with _log_config_logger_name if set

        The logger is cached on the instance after the first call.
        """
        if self._cached_logger is not None:
            return self._cached_logger

        typ = type(self)

        logger_name: str = (
            self._logger_name if self._logger_name is not None else f"{typ.__module__}.{typ.__qualname__}"
        )

        if self._log_config_logger_name:
            # Note: logger_name is always truthy here unless _logger_name was
            # explicitly set to ""; the fallback then uses the prefix alone.
            logger_name = (
                f"{self._log_config_logger_name}.{logger_name}"
                if logger_name
                else self._log_config_logger_name
            )

        self._cached_logger = structlog.get_logger(logger_name)
        return self._cached_logger

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Entry nodes of this object (itself for an operator; group roots for a group)."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Exit nodes of this object (itself for an operator; group leaves for a group)."""
        raise NotImplementedError()

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Set relatives for the task or task list.

        :param task_or_task_list: node(s) to connect to this one.
        :param upstream: if True the given nodes become upstream of ``self``,
            otherwise downstream.
        :param edge_modifier: optional modifier that records extra edge info.
        :raises TypeError: if a resolved relative is not an Operator.
        :raises RuntimeError: if the tasks span more than one Dag.
        :raises ValueError: if none of the involved tasks has a Dag yet.
        """
        # Imported here to avoid circular imports at module load time.
        from airflow.sdk.bases.operator import BaseOperator
        from airflow.sdk.definitions.mappedoperator import MappedOperator

        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        task_list: list[BaseOperator | MappedOperator] = []
        for task_object in task_or_task_list:
            task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
            # Connect against the other object's boundary nodes: its leaves
            # when it sits upstream of self, its roots when downstream.
            relatives = task_object.leaves if upstream else task_object.roots
            for task in relatives:
                if not isinstance(task, (BaseOperator, MappedOperator)):
                    raise TypeError(
                        f"Relationships can only be set between Operators; received {task.__class__.__name__}"
                    )
                task_list.append(task)

        # relationships can only be set if the tasks share a single Dag. Tasks
        # without a Dag are assigned to that Dag.
        dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}

        if len(dags) > 1:
            raise RuntimeError(f"Tried to set relationships between tasks in more than one Dag: {dags}")
        if len(dags) == 1:
            dag = dags.pop()
        else:
            raise ValueError(
                "Tried to create relationships between tasks that don't have Dags yet. "
                f"Set the Dag for at least one task and try again: {[self, *task_list]}"
            )

        if not self.has_dag():
            # If this task does not yet have a Dag, add it to the same Dag as the other task.
            self.dag = dag

        for task in task_list:
            if dag and not task.has_dag():
                # If the other task does not yet have a Dag, add it to the same Dag as this task and
                dag.add_task(task)  # type: ignore[arg-type]
            if upstream:
                task.downstream_task_ids.add(self.node_id)
                self.upstream_task_ids.add(task.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(dag, task.node_id, self.node_id)
            else:
                self.downstream_task_ids.add(task.node_id)
                task.upstream_task_ids.add(self.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(dag, self.node_id, task.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly downstream from the current node."""
        self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly upstream from the current node."""
        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)

    @property
    def downstream_list(self) -> Iterable[Operator]:
        """List of nodes directly downstream."""
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(tid) for tid in self.downstream_task_ids]

    @property
    def upstream_list(self) -> Iterable[Operator]:
        """List of nodes directly upstream."""
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(tid) for tid in self.upstream_task_ids]

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get set of the direct relative ids to the current task, upstream or downstream."""
        # NOTE: returns the live set, not a copy — callers must not mutate it.
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
        """Get list of the direct relatives to the current task, upstream or downstream."""
        if upstream:
            return self.upstream_list
        return self.downstream_list

    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        """
        dag = self.get_dag()
        if not dag:
            # Without a Dag there is no task registry to traverse.
            return set()

        relatives: set[str] = set()

        # This is intentionally implemented as a loop, instead of calling
        # get_direct_relative_ids() recursively, since Python has significant
        # limitation on stack level, and a recursive implementation can blow up
        # if a DAG contains very long routes.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        while task_ids_to_trace:
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    # Already visited — skip to keep the traversal terminating
                    # even if the graph contains cycles.
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next

        return relatives

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
        """Get a flat list of relatives, either upstream or downstream."""
        dag = self.get_dag()
        if not dag:
            return set()
        return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]

    def get_upstreams_follow_setups(self) -> Iterable[Operator]:
        """All upstreams and, for each upstream setup, its respective teardowns."""
        for task in self.get_flat_relatives(upstream=True):
            yield task
            if task.is_setup:
                for t in task.downstream_list:
                    # Exclude self so a teardown never depends on itself.
                    if t.is_teardown and t != self:
                        yield t

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant in this case means, the setup has a teardown that is downstream of ``self``,
        or the setup has no teardowns.
        """
        downstream_teardown_ids = {
            x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
        }
        for task in self.get_flat_relatives(upstream=True):
            if not task.is_setup:
                continue
            has_no_teardowns = not any(x.is_teardown for x in task.downstream_list)
            # if task has no teardowns or has teardowns downstream of self
            if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
                yield task
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t

    def get_upstreams_only_setups(self) -> Iterable[Operator]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        # Filters the setups out of get_upstreams_only_setups_and_teardowns(),
        # which also yields the accompanying teardowns.
        for task in self.get_upstreams_only_setups_and_teardowns():
            if task.is_setup:
                yield task

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize a task group's content; used by TaskGroupSerialization."""
        raise NotImplementedError()