1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import re
21from abc import ABCMeta, abstractmethod
22from collections.abc import Collection, Iterable, Sequence
23from datetime import datetime
24from typing import TYPE_CHECKING, Any
25
26import structlog
27
28from airflow.sdk.definitions._internal.mixins import DependencyMixin
29
30if TYPE_CHECKING:
31 from airflow.sdk.definitions.dag import DAG
32 from airflow.sdk.definitions.edges import EdgeModifier
33 from airflow.sdk.definitions.taskgroup import TaskGroup
34 from airflow.sdk.types import Logger, Operator
35 from airflow.serialization.enums import DagAttributeTypes
36
37
38KEY_REGEX = re.compile(r"^[\w.-]+$")
39GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
40CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")
41
42
43def validate_key(k: str, max_length: int = 250):
44 """Validate value used as a key."""
45 if not isinstance(k, str):
46 raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
47 if (length := len(k)) > max_length:
48 raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
49 if not KEY_REGEX.match(k):
50 raise ValueError(
51 f"The key {k!r} has to be made of alphanumeric characters, dashes, "
52 f"dots, and underscores exclusively"
53 )
54
55
56def validate_group_key(k: str, max_length: int = 200):
57 """Validate value used as a group key."""
58 if not isinstance(k, str):
59 raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
60 if (length := len(k)) > max_length:
61 raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
62 if not GROUP_KEY_REGEX.match(k):
63 raise ValueError(
64 f"The key {k!r} has to be made of alphanumeric characters, dashes, and underscores exclusively"
65 )
66
67
68class DAGNode(DependencyMixin, metaclass=ABCMeta):
69 """
70 A base class for a node in the graph of a workflow.
71
72 A node may be an Operator or a Task Group, either mapped or unmapped.
73 """
74
75 dag: DAG | None
76 task_group: TaskGroup | None
77 """The task_group that contains this node"""
78 start_date: datetime | None
79 end_date: datetime | None
80 upstream_task_ids: set[str]
81 downstream_task_ids: set[str]
82
83 _log_config_logger_name: str | None = None
84 _logger_name: str | None = None
85 _cached_logger: Logger | None = None
86
87 def __init__(self):
88 self.upstream_task_ids = set()
89 self.downstream_task_ids = set()
90 super().__init__()
91
92 def get_dag(self) -> DAG | None:
93 return self.dag
94
95 @property
96 @abstractmethod
97 def node_id(self) -> str:
98 raise NotImplementedError()
99
100 @property
101 def label(self) -> str | None:
102 tg = self.task_group
103 if tg and tg.node_id and tg.prefix_group_id:
104 # "task_group_id.task_id" -> "task_id"
105 return self.node_id[len(tg.node_id) + 1 :]
106 return self.node_id
107
108 def has_dag(self) -> bool:
109 return self.dag is not None
110
111 @property
112 def dag_id(self) -> str:
113 """Returns dag id if it has one or an adhoc/meaningless ID."""
114 if self.dag:
115 return self.dag.dag_id
116 return "_in_memory_dag_"
117
118 @property
119 def log(self) -> Logger:
120 """
121 Get a logger for this node.
122
123 The logger name is determined by:
124 1. Using _logger_name if provided
125 2. Otherwise, using the class's module and qualified name
126 3. Prefixing with _log_config_logger_name if set
127 """
128 if self._cached_logger is not None:
129 return self._cached_logger
130
131 typ = type(self)
132
133 logger_name: str = (
134 self._logger_name if self._logger_name is not None else f"{typ.__module__}.{typ.__qualname__}"
135 )
136
137 if self._log_config_logger_name:
138 logger_name = (
139 f"{self._log_config_logger_name}.{logger_name}"
140 if logger_name
141 else self._log_config_logger_name
142 )
143
144 self._cached_logger = structlog.get_logger(logger_name)
145 return self._cached_logger
146
147 @property
148 @abstractmethod
149 def roots(self) -> Sequence[DAGNode]:
150 raise NotImplementedError()
151
152 @property
153 @abstractmethod
154 def leaves(self) -> Sequence[DAGNode]:
155 raise NotImplementedError()
156
157 def _set_relatives(
158 self,
159 task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
160 upstream: bool = False,
161 edge_modifier: EdgeModifier | None = None,
162 ) -> None:
163 """Set relatives for the task or task list."""
164 from airflow.sdk.bases.operator import BaseOperator
165 from airflow.sdk.definitions.mappedoperator import MappedOperator
166
167 if not isinstance(task_or_task_list, Sequence):
168 task_or_task_list = [task_or_task_list]
169
170 task_list: list[BaseOperator | MappedOperator] = []
171 for task_object in task_or_task_list:
172 task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
173 relatives = task_object.leaves if upstream else task_object.roots
174 for task in relatives:
175 if not isinstance(task, (BaseOperator, MappedOperator)):
176 raise TypeError(
177 f"Relationships can only be set between Operators; received {task.__class__.__name__}"
178 )
179 task_list.append(task)
180
181 # relationships can only be set if the tasks share a single Dag. Tasks
182 # without a Dag are assigned to that Dag.
183 dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}
184
185 if len(dags) > 1:
186 raise RuntimeError(f"Tried to set relationships between tasks in more than one Dag: {dags}")
187 if len(dags) == 1:
188 dag = dags.pop()
189 else:
190 raise ValueError(
191 "Tried to create relationships between tasks that don't have Dags yet. "
192 f"Set the Dag for at least one task and try again: {[self, *task_list]}"
193 )
194
195 if not self.has_dag():
196 # If this task does not yet have a Dag, add it to the same Dag as the other task.
197 self.dag = dag
198
199 for task in task_list:
200 if dag and not task.has_dag():
201 # If the other task does not yet have a Dag, add it to the same Dag as this task and
202 dag.add_task(task) # type: ignore[arg-type]
203 if upstream:
204 task.downstream_task_ids.add(self.node_id)
205 self.upstream_task_ids.add(task.node_id)
206 if edge_modifier:
207 edge_modifier.add_edge_info(dag, task.node_id, self.node_id)
208 else:
209 self.downstream_task_ids.add(task.node_id)
210 task.upstream_task_ids.add(self.node_id)
211 if edge_modifier:
212 edge_modifier.add_edge_info(dag, self.node_id, task.node_id)
213
214 def set_downstream(
215 self,
216 task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
217 edge_modifier: EdgeModifier | None = None,
218 ) -> None:
219 """Set a node (or nodes) to be directly downstream from the current node."""
220 self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)
221
222 def set_upstream(
223 self,
224 task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
225 edge_modifier: EdgeModifier | None = None,
226 ) -> None:
227 """Set a node (or nodes) to be directly upstream from the current node."""
228 self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)
229
230 @property
231 def downstream_list(self) -> Iterable[Operator]:
232 """List of nodes directly downstream."""
233 if not self.dag:
234 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
235 return [self.dag.get_task(tid) for tid in self.downstream_task_ids]
236
237 @property
238 def upstream_list(self) -> Iterable[Operator]:
239 """List of nodes directly upstream."""
240 if not self.dag:
241 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
242 return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
243
244 def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
245 """Get set of the direct relative ids to the current task, upstream or downstream."""
246 if upstream:
247 return self.upstream_task_ids
248 return self.downstream_task_ids
249
250 def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
251 """Get list of the direct relatives to the current task, upstream or downstream."""
252 if upstream:
253 return self.upstream_list
254 return self.downstream_list
255
256 def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
257 """
258 Get a flat set of relative IDs, upstream or downstream.
259
260 Will recurse each relative found in the direction specified.
261
262 :param upstream: Whether to look for upstream or downstream relatives.
263 """
264 dag = self.get_dag()
265 if not dag:
266 return set()
267
268 relatives: set[str] = set()
269
270 # This is intentionally implemented as a loop, instead of calling
271 # get_direct_relative_ids() recursively, since Python has significant
272 # limitation on stack level, and a recursive implementation can blow up
273 # if a DAG contains very long routes.
274 task_ids_to_trace = self.get_direct_relative_ids(upstream)
275 while task_ids_to_trace:
276 task_ids_to_trace_next: set[str] = set()
277 for task_id in task_ids_to_trace:
278 if task_id in relatives:
279 continue
280 task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
281 relatives.add(task_id)
282 task_ids_to_trace = task_ids_to_trace_next
283
284 return relatives
285
286 def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
287 """Get a flat list of relatives, either upstream or downstream."""
288 dag = self.get_dag()
289 if not dag:
290 return set()
291 return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]
292
293 def get_upstreams_follow_setups(self) -> Iterable[Operator]:
294 """All upstreams and, for each upstream setup, its respective teardowns."""
295 for task in self.get_flat_relatives(upstream=True):
296 yield task
297 if task.is_setup:
298 for t in task.downstream_list:
299 if t.is_teardown and t != self:
300 yield t
301
302 def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
303 """
304 Only *relevant* upstream setups and their teardowns.
305
306 This method is meant to be used when we are clearing the task (non-upstream) and we need
307 to add in the *relevant* setups and their teardowns.
308
309 Relevant in this case means, the setup has a teardown that is downstream of ``self``,
310 or the setup has no teardowns.
311 """
312 downstream_teardown_ids = {
313 x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
314 }
315 for task in self.get_flat_relatives(upstream=True):
316 if not task.is_setup:
317 continue
318 has_no_teardowns = not any(x.is_teardown for x in task.downstream_list)
319 # if task has no teardowns or has teardowns downstream of self
320 if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
321 yield task
322 for t in task.downstream_list:
323 if t.is_teardown and t != self:
324 yield t
325
326 def get_upstreams_only_setups(self) -> Iterable[Operator]:
327 """
328 Return relevant upstream setups.
329
330 This method is meant to be used when we are checking task dependencies where we need
331 to wait for all the upstream setups to complete before we can run the task.
332 """
333 for task in self.get_upstreams_only_setups_and_teardowns():
334 if task.is_setup:
335 yield task
336
337 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
338 """Serialize a task group's content; used by TaskGroupSerialization."""
339 raise NotImplementedError()