1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20from typing import TYPE_CHECKING, Generic, TypeVar
21
22import structlog
23
24if TYPE_CHECKING:
25 from collections.abc import Collection, Iterable
26
27 from ..logging.types import Logger
28
# Type variables parameterizing GenericDAGNode with the concrete DAG, task,
# and task-group types supplied by the implementation that subclasses it.
Dag = TypeVar("Dag")
Task = TypeVar("Task")
TaskGroup = TypeVar("TaskGroup")
32
33
class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
    """
    Base building block for one node in a workflow graph.

    Concrete subclasses represent either an operator or a task group, in
    both their mapped and unmapped forms.
    """

    # Owning DAG, if this node has been assigned to one; set by subclasses.
    dag: Dag | None
    # Enclosing task group, if any; set by subclasses.
    task_group: TaskGroup | None
    downstream_group_ids: set[str | None]
    # Graph edges are stored as task IDs; the DAG resolves IDs to tasks.
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    _log_config_logger_name: str | None = None
    _logger_name: str | None = None
    _cached_logger: Logger | None = None

    def __init__(self):
        super().__init__()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    @property
    def log(self) -> Logger:
        """Structlog logger for this node, created lazily and then cached."""
        if self._cached_logger is None:
            cls = type(self)
            name = self._logger_name
            if name is None:
                name = f"{cls.__module__}.{cls.__qualname__}"
            # Prepend the optional configuration prefix when one is set.
            if self._log_config_logger_name:
                name = f"{self._log_config_logger_name}.{name}" if name else self._log_config_logger_name
            self._cached_logger = structlog.get_logger(name)
        return self._cached_logger

    @property
    def dag_id(self) -> str:
        """ID of the owning DAG, or a placeholder when not attached to one."""
        return self.dag.dag_id if self.dag else "_in_memory_dag_"

    @property
    def node_id(self) -> str:
        """Unique ID of this node; concrete subclasses must provide it."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """Display label: the node ID with any group prefix stripped off."""
        group = self.task_group
        if group and group.node_id and group.prefix_group_id:
            # Drop the leading "task_group_id." from the qualified node ID.
            prefix_length = len(group.node_id) + 1
            return self.node_id[prefix_length:]
        return self.node_id

    def _resolve_ids_to_tasks(self, task_ids: Iterable[str]) -> list[Task]:
        # Shared lookup for the *_list properties; requires an attached DAG.
        if not self.dag:
            raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
        return [self.dag.get_task(task_id) for task_id in task_ids]

    @property
    def upstream_list(self) -> Iterable[Task]:
        """Tasks directly upstream of this node, resolved via the DAG."""
        return self._resolve_ids_to_tasks(self.upstream_task_ids)

    @property
    def downstream_list(self) -> Iterable[Task]:
        """Tasks directly downstream of this node, resolved via the DAG."""
        return self._resolve_ids_to_tasks(self.downstream_task_ids)

    def has_dag(self) -> bool:
        """Whether this node has been attached to a DAG."""
        return self.dag is not None

    def get_dag(self) -> Dag | None:
        """Return the attached DAG, or ``None`` when unattached."""
        return self.dag

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Get set of the direct relative ids to the current task, upstream or downstream."""
        return self.upstream_task_ids if upstream else self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Task]:
        """Get list of the direct relatives to the current task, upstream or downstream."""
        return self.upstream_list if upstream else self.downstream_list

    def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
        """
        Collect every transitive relative ID in one direction.

        Implemented as a breadth-first loop rather than recursion on
        purpose: Python's stack depth is limited, and a recursive walk
        would blow up on DAGs with very long routes.

        :param upstream: Whether to look for upstream or downstream relatives.
        """
        dag = self.get_dag()
        if not dag:
            return set()

        seen: set[str] = set()
        frontier = self.get_direct_relative_ids(upstream)
        while frontier:
            next_frontier: set[str] = set()
            for task_id in frontier:
                if task_id in seen:
                    continue
                seen.add(task_id)
                next_frontier |= dag.task_dict[task_id].get_direct_relative_ids(upstream)
            frontier = next_frontier

        return seen

    def get_flat_relatives(self, upstream: bool = False) -> Collection[Task]:
        """Get a flat list of relatives, either upstream or downstream."""
        dag = self.get_dag()
        if not dag:
            return set()
        relative_ids = self.get_flat_relative_ids(upstream=upstream)
        return [dag.task_dict[task_id] for task_id in relative_ids]

    def get_upstreams_follow_setups(self) -> Iterable[Task]:
        """All upstreams and, for each upstream setup, its respective teardowns."""
        for ancestor in self.get_flat_relatives(upstream=True):
            yield ancestor
            if not ancestor.is_setup:
                continue
            for maybe_teardown in ancestor.downstream_list:
                if maybe_teardown.is_teardown and maybe_teardown != self:
                    yield maybe_teardown

    def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Task]:
        """
        Yield only the *relevant* upstream setups, each followed by its teardowns.

        Meant for when this (non-upstream) task is being cleared and the
        relevant setups plus their teardowns must be added in as well.

        A setup counts as relevant when one of its teardowns lies downstream
        of ``self``, or when it has no teardowns at all.
        """
        my_downstream_teardown_ids = {
            t.task_id for t in self.get_flat_relatives(upstream=False) if t.is_teardown
        }
        for ancestor in self.get_flat_relatives(upstream=True):
            if not ancestor.is_setup:
                continue
            teardown_free = not any(t.is_teardown for t in ancestor.downstream_list)
            # Relevant: no teardowns at all, or a teardown downstream of self.
            if teardown_free or ancestor.downstream_task_ids & my_downstream_teardown_ids:
                yield ancestor
                for candidate in ancestor.downstream_list:
                    if candidate.is_teardown and candidate != self:
                        yield candidate

    def get_upstreams_only_setups(self) -> Iterable[Task]:
        """
        Return relevant upstream setups.

        Meant for dependency checking, where every relevant upstream setup
        must complete before this task can run.
        """
        for candidate in self.get_upstreams_only_setups_and_teardowns():
            if candidate.is_setup:
                yield candidate