# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar

import structlog

if TYPE_CHECKING:
    import sys
    from collections.abc import Collection, Iterable

    # Replicate `airflow.typing_compat.Self` to avoid illegal imports
    if sys.version_info >= (3, 11):
        from typing import Self
    else:
        from typing_extensions import Self

    from ..logging.types import Logger


class DagProtocol(Protocol):
    """Protocol defining the minimum interface required for the ``Dag`` generic type."""

    dag_id: str
    task_dict: dict[str, Any]

    def get_task(self, tid: str) -> Any:
        """Retrieve a task by its task ID."""
        ...


class TaskProtocol(Protocol):
    """Protocol defining the minimum interface required for the ``Task`` generic type."""

    task_id: str
    is_setup: bool
    is_teardown: bool
    downstream_list: Iterable[Self]
    downstream_task_ids: set[str]


class TaskGroupProtocol(Protocol):
    """Protocol defining the minimum interface required for the ``TaskGroup`` generic type."""

    node_id: str
    prefix_group_id: bool


Dag = TypeVar("Dag", bound=DagProtocol)
Task = TypeVar("Task", bound=TaskProtocol)
TaskGroup = TypeVar("TaskGroup", bound=TaskGroupProtocol)
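
# A minimal sketch of how these type variables are meant to be bound (the
# ``MyDag``/``MyTask``/``MyTaskGroup`` names are hypothetical stand-ins for
# concrete classes satisfying the protocols above)::
#
#     class MyNode(GenericDAGNode["MyDag", "MyTask", "MyTaskGroup"]):
#         task_id: str
#
#         @property
#         def node_id(self) -> str:
#             return self.task_id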


class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
    """
    Generic class for a node in the graph of a workflow.

    A node may be an operator or task group, either mapped or unmapped.
    """

    dag: Dag | None
    task_group: TaskGroup | None
    downstream_group_ids: set[str | None]
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    _log_config_logger_name: str | None = None
    _logger_name: str | None = None
    _cached_logger: Logger | None = None

    def __init__(self):
        super().__init__()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    @property
    def log(self) -> Logger:
        if self._cached_logger is not None:
            return self._cached_logger

        typ = type(self)

        logger_name: str = (
            self._logger_name if self._logger_name is not None else f"{typ.__module__}.{typ.__qualname__}"
        )

        if self._log_config_logger_name:
            logger_name = (
                f"{self._log_config_logger_name}.{logger_name}"
                if logger_name
                else self._log_config_logger_name
            )

        self._cached_logger = structlog.get_logger(logger_name)
        return self._cached_logger

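    # For example (illustrative values, not defaults of this class): with
    # ``_log_config_logger_name = "airflow.task"`` on a class located at
    # ``mypkg.MyOperator``, the property above resolves the logger name to
    # ``"airflow.task.mypkg.MyOperator"``; if ``_logger_name`` is set to
    # ``""``, the config logger name ``"airflow.task"`` is used on its own.
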
    @property
    def dag_id(self) -> str:
        if self.dag:
            return self.dag.dag_id
        return "_in_memory_dag_"

    @property
    def node_id(self) -> str:
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(tg.node_id) + 1 :]
        return self.node_id

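    # For instance, a node with ``node_id == "section_1.task_a"`` inside a
    # prefixing task group whose ``node_id == "section_1"`` is labelled
    # ``"task_a"``: the group prefix and its dot separator are stripped.
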
    @property
    def upstream_list(self) -> Iterable[Task]:
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(tid) for tid in self.upstream_task_ids]

    @property
    def downstream_list(self) -> Iterable[Task]:
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(tid) for tid in self.downstream_task_ids]

    def has_dag(self) -> bool:
        return self.dag is not None

    def get_dag(self) -> Dag | None:
        return self.dag

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get the set of direct relative ids of the current task, upstream or downstream."""
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Task]:
        """Get the list of direct relatives of the current task, upstream or downstream."""
        if upstream:
            return self.upstream_list
        return self.downstream_list

    def get_flat_relative_ids(self, *, upstream: bool = False, depth: int | None = None) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Recurses into each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        if depth is not None and depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")

        dag = self.get_dag()
        if not dag:
            return set()

        relatives: set[str] = set()

        # This is intentionally implemented as a loop, instead of calling
        # get_direct_relative_ids() recursively, since Python places a hard
        # limit on stack depth and a recursive implementation can blow up
        # if a DAG contains very long routes.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        levels_remaining = depth
        while task_ids_to_trace:
            # If depth is set, traversal is bounded and should stop once
            # there are no more levels remaining.
            if levels_remaining is not None and levels_remaining <= 0:
                break
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next
            if levels_remaining is not None:
                levels_remaining -= 1

        return relatives

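    # A small traversal example (hypothetical task ids): in a linear chain
    # ``a -> b -> c -> d``, ``d.get_flat_relative_ids(upstream=True)`` returns
    # ``{"a", "b", "c"}``, while passing ``depth=2`` bounds the walk to the
    # two nearest levels and returns ``{"b", "c"}``.
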
    def get_flat_relatives(self, upstream: bool = False, depth: int | None = None) -> Collection[Task]:
        """
        Get a flat list of relatives, either upstream or downstream.

        :param upstream: Whether to look for upstream or downstream relatives.
        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        dag = self.get_dag()
        if not dag:
            return set()
        return [
            dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream, depth=depth)
        ]

    def get_upstreams_follow_setups(self, depth: int | None = None) -> Iterable[Task]:
        """
        All upstreams and, for each upstream setup, its respective teardowns.

        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        for task in self.get_flat_relatives(upstream=True, depth=depth):
            yield task
            if task.is_setup:
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t

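    # Sketch of the behaviour above with hypothetical tasks: given
    # ``setup >> work >> self`` plus ``setup >> teardown``, calling
    # ``self.get_upstreams_follow_setups()`` yields ``work`` and ``setup``,
    # and also ``teardown``, since it is a teardown downstream of an upstream
    # setup (a teardown equal to ``self`` would be skipped).
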
    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Task]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant in this case means the setup has a teardown that is downstream of ``self``,
        or the setup has no teardowns at all.
        """
        downstream_teardown_ids = {
            x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
        }
        for task in self.get_flat_relatives(upstream=True):
            if not task.is_setup:
                continue
            has_no_teardowns = not any(x.is_teardown for x in task.downstream_list)
            # yield the setup if it has no teardowns, or has a teardown downstream of self
            if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
                yield task
                for t in task.downstream_list:
                    if t.is_teardown and t != self:
                        yield t

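    # Illustrative example (hypothetical tasks): a ``setup_a`` whose teardown
    # is downstream of ``self`` is *relevant*, so both are yielded; a
    # ``setup_b`` whose only teardown lives on an unrelated branch that never
    # reaches downstream of ``self`` is filtered out.
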
    def get_upstreams_only_setups(self) -> Iterable[Task]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        for task in self.get_upstreams_only_setups_and_teardowns():
            if task.is_setup:
                yield task