Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/taskmixin.py: 61%
140 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import warnings
20from abc import ABCMeta, abstractmethod
21from typing import TYPE_CHECKING, Any, Iterable, Sequence
23import pendulum
25from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
26from airflow.serialization.enums import DagAttributeTypes
28if TYPE_CHECKING:
29 from logging import Logger
31 from airflow.models.dag import DAG
32 from airflow.models.operator import Operator
33 from airflow.utils.edgemodifier import EdgeModifier
34 from airflow.utils.task_group import TaskGroup
class DependencyMixin:
    """
    Mixin implementing common dependency-setting methods such as ``>>`` and ``<<``.

    Subclasses provide :meth:`set_upstream` / :meth:`set_downstream` (and usually
    :attr:`roots` / :attr:`leaves`); this mixin maps the bit-shift operators onto them.
    """

    @property
    def roots(self) -> Sequence[DependencyMixin]:
        """
        List of root nodes -- ones with no upstream dependencies.

        a.k.a. the "start" of this sub-graph
        """
        raise NotImplementedError()

    @property
    def leaves(self) -> Sequence[DependencyMixin]:
        """
        List of leaf nodes -- ones with no downstream dependencies.

        a.k.a. the "end" of this sub-graph
        """
        raise NotImplementedError()

    @abstractmethod
    def set_upstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Set a task or a task list to be directly upstream from the current task."""
        raise NotImplementedError()

    @abstractmethod
    def set_downstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Set a task or a task list to be directly downstream from the current task."""
        raise NotImplementedError()

    def update_relative(self, other: DependencyMixin, upstream=True) -> None:
        """
        Update relationship information about another DependencyMixin. Default is no-op.

        Override if necessary. As called from ``DAGNode._set_relatives`` (which passes
        ``not upstream``), ``upstream`` is True when *other* is being placed upstream of
        this object and False when it is being placed downstream.
        """

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``self << other``: make *other* directly upstream of ``self``.

        Returns *other* so that chains such as ``a << b << c`` work left to right.
        """
        self.set_upstream(other)
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``self >> other``: make *other* directly downstream of ``self``.

        Returns *other* so that chains such as ``a >> b >> c`` work left to right.
        """
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Called for ``[Task] >> Task`` because lists don't have ``__rshift__`` operators.

        ``[a, b] >> self`` means ``a`` and ``b`` run before ``self``, i.e. ``self << [a, b]``.
        Returns ``self`` so the chain can continue from the right-hand operand.
        """
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Called for ``[Task] << Task`` because lists don't have ``__lshift__`` operators.

        ``[a, b] << self`` means ``a`` and ``b`` run after ``self``, i.e. ``self >> [a, b]``.
        Returns ``self`` so the chain can continue from the right-hand operand.
        """
        self.__rshift__(other)
        return self
class TaskMixin(DependencyMixin):
    """
    Deprecated alias for :class:`DependencyMixin`; subclassing it emits a deprecation warning.

    :meta private:
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Runs once per subclass definition. stacklevel=2 skips this method's own frame
        # when the warning is attributed.
        warnings.warn(
            f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}",
            category=RemovedInAirflow3Warning,
            stacklevel=2,
        )
        # Forward class keywords so other bases in the MRO can consume them
        # (cooperative __init_subclass__, PEP 487).
        return super().__init_subclass__(**kwargs)
class DAGNode(DependencyMixin, metaclass=ABCMeta):
    """
    Abstract base for anything that can sit in a workflow graph.

    Concrete nodes are Operators and Task Groups, each of which may be mapped or
    unmapped. Subclasses must supply :attr:`node_id`, :attr:`roots` and :attr:`leaves`.
    """

    dag: DAG | None = None
    task_group: TaskGroup | None = None
    """The task_group that contains this node"""

    @property
    @abstractmethod
    def node_id(self) -> str:
        """Identifier of this node; it is what gets stored in the upstream/downstream id sets."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """
        The node id relative to the enclosing task group.

        When the node belongs to a task group that has a truthy ``node_id`` and a truthy
        ``prefix_group_id``, the leading group id and its separator are stripped;
        otherwise :attr:`node_id` is returned unchanged.
        """
        group = self.task_group
        if not (group and group.node_id and group.prefix_group_id):
            return self.node_id
        # "task_group_id.task_id" -> "task_id": drop the group id plus the "." separator.
        return self.node_id[len(group.node_id) + 1 :]

    # Declared (not assigned) here; concrete subclasses are expected to provide them.
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    # node_ids of direct neighbours; _set_relatives adds to these sets.
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    def has_dag(self) -> bool:
        """Return True when a DAG has been attached to this node."""
        return self.dag is not None

    @property
    def dag_id(self) -> str:
        """The owning DAG's id, or the placeholder ``"_in_memory_dag_"`` when there is none."""
        return self.dag.dag_id if self.dag else "_in_memory_dag_"

    @property
    def log(self) -> Logger:
        """Logger used to report duplicate dependencies; subclasses must provide it."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Nodes at the start of this node's sub-graph; must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Nodes at the end of this node's sub-graph; must be implemented by subclasses."""
        raise NotImplementedError()

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Wire ``task_or_task_list`` to this node as upstream or downstream relatives.

        :param task_or_task_list: a single node or a sequence of nodes; anything that is
            not a ``Sequence`` is treated as a single node.
        :param upstream: if True the given nodes become upstream of ``self``, otherwise
            downstream.
        :param edge_modifier: when given, edge info is recorded for every edge created.
        :raises AirflowException: if an endpoint is not a BaseOperator/MappedOperator, if
            the involved tasks belong to more than one DAG, or if none of them has a DAG.
        """
        # Imported lazily rather than at module level (presumably to avoid import cycles).
        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.operator import Operator

        others = task_or_task_list if isinstance(task_or_task_list, Sequence) else [task_or_task_list]

        # Upstream edges attach to the other side's leaves, downstream edges to its roots.
        operators: list[Operator] = []
        for other in others:
            other.update_relative(self, not upstream)
            for endpoint in (other.leaves if upstream else other.roots):
                if not isinstance(endpoint, (BaseOperator, MappedOperator)):
                    raise AirflowException(
                        f"Relationships can only be set between Operators; received {endpoint.__class__.__name__}"
                    )
                operators.append(endpoint)

        # Every participant must share one DAG; participants without a DAG adopt it below.
        dags: set[DAG] = {node.dag for node in [*self.roots, *operators] if node.has_dag() and node.dag}
        if not dags:
            raise AirflowException(
                f"Tried to create relationships between tasks that don't have DAGs yet. "
                f"Set the DAG for at least one task and try again: {[self, *operators]}"
            )
        if len(dags) > 1:
            raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
        (dag,) = dags

        if not self.has_dag():
            # This node has no DAG yet: attach it to the one shared by the other side.
            self.dag = dag

        def _register_once(owner, ids: set[str], node_id: str) -> None:
            """Add ``node_id`` to ``ids``, or log a warning if it is already present."""
            if node_id not in ids:
                ids.add(node_id)
                return
            self.log.warning("Dependency %s, %s already registered for DAG: %s", owner, node_id, dag.dag_id)

        for op in operators:
            if dag and not op.has_dag():
                # Pull a DAG-less operator into the shared DAG.
                dag.add_task(op)
            if upstream:
                # Edge direction: op -> self
                _register_once(op, op.downstream_task_ids, self.node_id)
                _register_once(self, self.upstream_task_ids, op.node_id)
            else:
                # Edge direction: self -> op
                _register_once(self, self.downstream_task_ids, op.node_id)
                _register_once(op, op.upstream_task_ids, self.node_id)
            if edge_modifier:
                source, target = (op.node_id, self.node_id) if upstream else (self.node_id, op.node_id)
                edge_modifier.add_edge_info(self.dag, source, target)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make the given node(s) run directly after this node."""
        self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make the given node(s) run directly before this node."""
        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)

    @property
    def downstream_list(self) -> Iterable[Operator]:
        """Direct downstream nodes, looked up by id in the owning DAG."""
        owner = self.dag
        if not owner:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [owner.get_task(task_id) for task_id in self.downstream_task_ids]

    @property
    def upstream_list(self) -> Iterable[Operator]:
        """Direct upstream nodes, looked up by id in the owning DAG."""
        owner = self.dag
        if not owner:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [owner.get_task(task_id) for task_id in self.upstream_task_ids]

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """
        Return the ids of the direct upstream (``upstream=True``) or downstream relatives.

        Note: the node's own id set is returned, not a copy.
        """
        return self.upstream_task_ids if upstream else self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
        """Return the direct upstream (``upstream=True``) or downstream relatives as nodes."""
        return self.upstream_list if upstream else self.downstream_list

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize this node as part of a task group's content (used by TaskGroupSerialization)."""
        raise NotImplementedError()