Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/taskmixin.py: 61%
136 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import warnings
20from abc import ABCMeta, abstractmethod
21from typing import TYPE_CHECKING, Any, Iterable, Sequence
23import pendulum
25from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
26from airflow.serialization.enums import DagAttributeTypes
28if TYPE_CHECKING:
29 from logging import Logger
31 from airflow.models.dag import DAG
32 from airflow.models.operator import Operator
33 from airflow.utils.edgemodifier import EdgeModifier
34 from airflow.utils.task_group import TaskGroup
class DependencyMixin:
    """Mixin implementing common dependency-setting methods such as ``>>`` and ``<<``."""

    @property
    def roots(self) -> Sequence[DependencyMixin]:
        """
        List of root nodes -- the ones that have no upstream dependencies.

        In other words, the "start" of this sub-graph. Must be provided by subclasses.
        """
        raise NotImplementedError

    @property
    def leaves(self) -> Sequence[DependencyMixin]:
        """
        List of leaf nodes -- the ones that only have upstream dependencies.

        In other words, the "end" of this sub-graph. Must be provided by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def set_upstream(
        self,
        other: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Set a task or a task list to be directly upstream from the current task."""
        raise NotImplementedError

    @abstractmethod
    def set_downstream(
        self,
        other: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Set a task or a task list to be directly downstream from the current task."""
        raise NotImplementedError

    def update_relative(
        self,
        other: DependencyMixin,
        upstream: bool = True,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Update relationship information about another node.

        The default implementation does nothing; override it where needed.
        """

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``self << other``: make ``other`` upstream of ``self`` and return ``other``."""
        self.set_upstream(other)
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``self >> other``: make ``other`` downstream of ``self`` and return ``other``."""
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``other >> self`` when ``other`` has no ``__rshift__`` of its own (e.g. a list).

        Delegates to ``__lshift__`` so that ``other`` becomes upstream of ``self``; returns ``self``.
        """
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``other << self`` when ``other`` has no ``__lshift__`` of its own (e.g. a list).

        Delegates to ``__rshift__`` so that ``other`` becomes downstream of ``self``; returns ``self``.
        """
        self.__rshift__(other)
        return self
class TaskMixin(DependencyMixin):
    """Deprecated alias of ``DependencyMixin``, kept so existing subclasses keep working.

    Defining a subclass emits a ``RemovedInAirflow3Warning`` naming that subclass.

    :meta private:
    """

    def __init_subclass__(cls) -> None:
        # stacklevel=2 attributes the warning to the code that triggered this
        # hook rather than to this method itself.
        message = f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}"
        warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
        return super().__init_subclass__()
class DAGNode(DependencyMixin, metaclass=ABCMeta):
    """
    Base class for a node in the graph of a workflow.

    A node is an Operator or a Task Group, either mapped or unmapped.
    """

    dag: DAG | None = None
    task_group: TaskGroup | None = None
    """The task_group that contains this node"""

    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    @property
    @abstractmethod
    def node_id(self) -> str:
        """Identifier of this node; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def label(self) -> str | None:
        """
        Node id with the enclosing task group's prefix removed.

        The prefix is only stripped when the node has a task group whose ``node_id`` and
        ``prefix_group_id`` are both truthy; otherwise the node id is returned unchanged.
        """
        group = self.task_group
        if not (group and group.node_id and group.prefix_group_id):
            return self.node_id
        # "task_group_id.task_id" -> "task_id" (the +1 skips the separating character)
        full_id = self.node_id
        return full_id[len(group.node_id) + 1 :]

    def has_dag(self) -> bool:
        """Return True when ``dag`` is set (i.e. is not None)."""
        return self.dag is not None

    @property
    def dag_id(self) -> str:
        """Returns dag id if it has one or an adhoc/meaningless ID."""
        owner = self.dag
        return owner.dag_id if owner else "_in_memory_dag_"

    @property
    def log(self) -> Logger:
        """Logger for this node; the base implementation is not provided."""
        raise NotImplementedError

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Root nodes of this node's sub-graph; must be provided by subclasses."""
        raise NotImplementedError

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Leaf nodes of this node's sub-graph; must be provided by subclasses."""
        raise NotImplementedError

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Set relatives for the task or task list.

        :param task_or_task_list: a single node or a sequence of nodes to relate to this one
        :param upstream: when True the given nodes become upstream of this node,
            otherwise they become downstream
        :param edge_modifier: optional modifier that is told about every edge created
        :raises AirflowException: if a resolved relative is not an operator, if the tasks
            involved belong to more than one DAG, or if none of them has a DAG
        """
        # Imported here rather than at module level, as in the original code.
        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.operator import Operator

        if isinstance(task_or_task_list, Sequence):
            candidates = task_or_task_list
        else:
            candidates = [task_or_task_list]

        operator_types = (BaseOperator, MappedOperator)
        task_list: list[Operator] = []
        for node in candidates:
            # Let the other side record the relationship from its own point of view.
            node.update_relative(self, not upstream, edge_modifier=edge_modifier)
            # Upstream nodes connect through their leaves, downstream nodes through their roots.
            for task in node.leaves if upstream else node.roots:
                if isinstance(task, operator_types):
                    task_list.append(task)
                    continue
                raise AirflowException(
                    f"Relationships can only be set between Operators; received {task.__class__.__name__}"
                )

        # Relationships can only be set if the tasks share a single DAG. Tasks
        # without a DAG are assigned to that DAG.
        dags: set[DAG] = set()
        for member in (*self.roots, *task_list):
            if member.has_dag() and member.dag:
                dags.add(member.dag)

        if not dags:
            raise AirflowException(
                f"Tried to create relationships between tasks that don't have DAGs yet. "
                f"Set the DAG for at least one task and try again: {[self, *task_list]}"
            )
        if len(dags) > 1:
            raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
        dag = dags.pop()

        if not self.has_dag():
            # This node has no DAG yet: adopt the one shared by the other tasks.
            self.dag = dag

        for task in task_list:
            if dag and not task.has_dag():
                # The other task has no DAG yet: add it to the shared DAG.
                dag.add_task(task)
            # The edge always runs source -> target; only the roles swap with ``upstream``.
            source, target = (task, self) if upstream else (self, task)
            source.downstream_task_ids.add(target.node_id)
            target.upstream_task_ids.add(source.node_id)
            if edge_modifier:
                edge_modifier.add_edge_info(self.dag, source.node_id, target.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly downstream from the current node."""
        self._set_relatives(
            task_or_task_list,
            upstream=False,
            edge_modifier=edge_modifier,
        )

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly upstream from the current node."""
        self._set_relatives(
            task_or_task_list,
            upstream=True,
            edge_modifier=edge_modifier,
        )

    @property
    def downstream_list(self) -> Iterable[Operator]:
        """
        List of nodes directly downstream.

        :raises AirflowException: if this node has not been assigned to a DAG
        """
        if not self.dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        downstream = []
        for tid in self.downstream_task_ids:
            downstream.append(self.dag.get_task(tid))
        return downstream

    @property
    def upstream_list(self) -> Iterable[Operator]:
        """
        List of nodes directly upstream.

        :raises AirflowException: if this node has not been assigned to a DAG
        """
        if not self.dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        upstream = []
        for tid in self.upstream_task_ids:
            upstream.append(self.dag.get_task(tid))
        return upstream

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """
        Get the set of ids of the direct relatives of the current task.

        :param upstream: return the upstream ids when True, the downstream ids otherwise
        """
        return self.upstream_task_ids if upstream else self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
        """
        Get the direct relatives of the current task.

        :param upstream: return the upstream nodes when True, the downstream nodes otherwise
        """
        return self.upstream_list if upstream else self.downstream_list

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """This is used by TaskGroupSerialization to serialize a task group's content."""
        raise NotImplementedError