1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20from typing import TYPE_CHECKING, Generic, TypeVar
21
22import structlog
23
24if TYPE_CHECKING:
25 from collections.abc import Collection, Iterable
26
27 from ..logging.types import Logger
28
# Type variables for the concrete DAG, task, and task-group types supplied by
# the implementing package; this module only relies on their graph interface
# (e.g. ``dag_id``, ``get_task``, ``task_dict``, ``is_setup``/``is_teardown``).
Dag = TypeVar("Dag")
Task = TypeVar("Task")
TaskGroup = TypeVar("TaskGroup")
32
33
class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
    """
    Generic class for a node in the graph of a workflow.

    A node may be an operator or task group, either mapped or unmapped.
    """

    # Declared here for typing; concrete subclasses are expected to populate
    # ``dag``, ``task_group`` and ``downstream_group_ids``. The task-id sets
    # are initialized in ``__init__``.
    dag: Dag | None
    task_group: TaskGroup | None
    downstream_group_ids: set[str | None]
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    # Logger configuration; subclasses may override these to control which
    # logger name ``self.log`` resolves to.
    _log_config_logger_name: str | None = None
    _logger_name: str | None = None
    _cached_logger: Logger | None = None

    def __init__(self):
        super().__init__()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    @property
    def log(self) -> Logger:
        """Return the structlog logger for this node, creating and caching it on first access."""
        if self._cached_logger is not None:
            return self._cached_logger

        typ = type(self)

        # Fall back to the fully qualified class name when no explicit logger
        # name was configured. Note an empty string ``_logger_name`` is a valid
        # override and is distinct from ``None``.
        logger_name: str = (
            self._logger_name if self._logger_name is not None else f"{typ.__module__}.{typ.__qualname__}"
        )

        if self._log_config_logger_name:
            # Nest under the configured parent logger so logging configuration
            # applied to the parent also covers this node; an empty
            # ``logger_name`` means "use the parent name verbatim".
            logger_name = (
                f"{self._log_config_logger_name}.{logger_name}"
                if logger_name
                else self._log_config_logger_name
            )

        self._cached_logger = structlog.get_logger(logger_name)
        return self._cached_logger

    @property
    def dag_id(self) -> str:
        """Return the owning DAG's id, or a placeholder id if no DAG is assigned."""
        if self.dag:
            return self.dag.dag_id
        return "_in_memory_dag_"

    @property
    def node_id(self) -> str:
        """Unique identifier of this node within its DAG; must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """The node id with the enclosing task group's prefix stripped, if any."""
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(tg.node_id) + 1 :]
        return self.node_id

    @property
    def upstream_list(self) -> Iterable[Task]:
        """
        Direct upstream tasks, resolved through the owning DAG.

        :raises RuntimeError: If this node has not been assigned to a DAG.
        """
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(tid) for tid in self.upstream_task_ids]

    @property
    def downstream_list(self) -> Iterable[Task]:
        """
        Direct downstream tasks, resolved through the owning DAG.

        :raises RuntimeError: If this node has not been assigned to a DAG.
        """
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(tid) for tid in self.downstream_task_ids]

    def has_dag(self) -> bool:
        """Return True if this node has been assigned to a DAG."""
        return self.dag is not None

    def get_dag(self) -> Dag | None:
        """Return the DAG this node belongs to, or None if unassigned."""
        return self.dag

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get set of the direct relative ids to the current task, upstream or downstream."""
        if upstream:
            return self.upstream_task_ids
        return self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Task]:
        """Get list of the direct relatives to the current task, upstream or downstream."""
        if upstream:
            return self.upstream_list
        return self.downstream_list

    def get_flat_relative_ids(self, *, upstream: bool = False, depth: int | None = None) -> set[str]:
        """
        Get a flat set of relative IDs, upstream or downstream.

        Will recurse each relative found in the direction specified.

        :param upstream: Whether to look for upstream or downstream relatives.
        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        :raises ValueError: If *depth* is negative.
        """
        if depth is not None and depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")

        dag = self.get_dag()
        if not dag:
            return set()

        relatives: set[str] = set()

        # This is intentionally implemented as a loop, instead of calling
        # get_direct_relative_ids() recursively, since Python has significant
        # limitation on stack level, and a recursive implementation can blow up
        # if a DAG contains very long routes.
        task_ids_to_trace = self.get_direct_relative_ids(upstream)
        levels_remaining = depth
        while task_ids_to_trace:
            # if depth is set we have bounded traversal and should break when
            # there are no more levels remaining
            if levels_remaining is not None and levels_remaining <= 0:
                break
            task_ids_to_trace_next: set[str] = set()
            for task_id in task_ids_to_trace:
                if task_id in relatives:
                    # Already expanded: skip to avoid redundant work on
                    # diamond-shaped graphs and endless loops on cycles.
                    continue
                task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
                relatives.add(task_id)
            task_ids_to_trace = task_ids_to_trace_next
            if levels_remaining is not None:
                levels_remaining -= 1

        return relatives

    def get_flat_relatives(self, upstream: bool = False, depth: int | None = None) -> Collection[Task]:
        """
        Get a flat list of relatives, either upstream or downstream.

        :param upstream: Whether to look for upstream or downstream relatives.
        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        :raises ValueError: If *depth* is negative.
        """
        # Validate eagerly so the documented contract holds even when this node
        # has no DAG; previously an invalid depth was silently ignored in that
        # case because we returned before get_flat_relative_ids() could reject it.
        if depth is not None and depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")
        dag = self.get_dag()
        if not dag:
            return set()
        return [
            dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream, depth=depth)
        ]

    def _teardowns_for(self, task: Task) -> Iterable[Task]:
        """Yield teardown tasks directly downstream of *task*, excluding ``self``."""
        for t in task.downstream_list:
            if t.is_teardown and t != self:
                yield t

    def get_upstreams_follow_setups(self, depth: int | None = None) -> Iterable[Task]:
        """
        All upstreams and, for each upstream setup, its respective teardowns.

        :param depth: Maximum number of levels to traverse. If None, traverses all levels.
            Must be non-negative.
        """
        for task in self.get_flat_relatives(upstream=True, depth=depth):
            yield task
            if task.is_setup:
                yield from self._teardowns_for(task)

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Task]:
        """
        Only *relevant* upstream setups and their teardowns.

        This method is meant to be used when we are clearing the task (non-upstream) and we need
        to add in the *relevant* setups and their teardowns.

        Relevant in this case means, the setup has a teardown that is downstream of ``self``,
        or the setup has no teardowns.
        """
        downstream_teardown_ids = {
            x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
        }
        for task in self.get_flat_relatives(upstream=True):
            if not task.is_setup:
                continue
            has_no_teardowns = not any(x.is_teardown for x in task.downstream_list)
            # if task has no teardowns or has teardowns downstream of self
            if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
                yield task
                yield from self._teardowns_for(task)

    def get_upstreams_only_setups(self) -> Iterable[Task]:
        """
        Return relevant upstream setups.

        This method is meant to be used when we are checking task dependencies where we need
        to wait for all the upstream setups to complete before we can run the task.
        """
        for task in self.get_upstreams_only_setups_and_teardowns():
            if task.is_setup:
                yield task