Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/taskmixin.py: 56%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import warnings
20from abc import ABCMeta, abstractmethod
21from typing import TYPE_CHECKING, Any, Iterable, Sequence
23from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
24from airflow.utils.types import NOTSET
26if TYPE_CHECKING:
27 from logging import Logger
29 import pendulum
31 from airflow.models.baseoperator import BaseOperator
32 from airflow.models.dag import DAG
33 from airflow.models.operator import Operator
34 from airflow.serialization.enums import DagAttributeTypes
35 from airflow.utils.edgemodifier import EdgeModifier
36 from airflow.utils.task_group import TaskGroup
37 from airflow.utils.types import ArgNotSet
class DependencyMixin:
    """Mixin implementing common dependency setting methods like >> and <<."""

    @property
    def roots(self) -> Sequence[DependencyMixin]:
        """
        List of root nodes -- ones with no upstream dependencies.

        a.k.a. the "start" of this sub-graph

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    @property
    def leaves(self) -> Sequence[DependencyMixin]:
        """
        List of leaf nodes -- ones with only upstream dependencies.

        a.k.a. the "end" of this sub-graph

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_upstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """
        Set a task or a task list to be directly upstream from the current task.

        :param other: a single dependency or a sequence of them.
        :param edge_modifier: optional edge modifier to associate with the new edge(s).
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_downstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """
        Set a task or a task list to be directly downstream from the current task.

        :param other: a single dependency or a sequence of them.
        :param edge_modifier: optional edge modifier to associate with the new edge(s).
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def as_setup(self) -> DependencyMixin:
        """
        Mark a task as setup task.

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
        on_failure_fail_dagrun=NOTSET,
    ) -> DependencyMixin:
        """
        Mark a task as teardown and set its setups as direct relatives.

        Both arguments are keyword-only and default to the ``NOTSET`` sentinel.

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Update relationship information about another TaskMixin. Default is no-op.

        Override if necessary.
        """

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement Task << Task; returns ``other`` so that expressions can be chained."""
        self.set_upstream(other)
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement Task >> Task; returns ``other`` so that expressions can be chained."""
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement [Task] >> Task because lists don't have __rshift__ operators."""
        # Reflected operator: ``other`` is the LEFT operand, so it becomes upstream of self.
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement [Task] << Task because lists don't have __lshift__ operators."""
        # Reflected operator: ``other`` is the LEFT operand, so it becomes downstream of self.
        self.__rshift__(other)
        return self

    @classmethod
    def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]:
        """
        Yield ``(node, kind)`` pairs for every reference found in ``obj``.

        An ``AbstractOperator`` yields itself tagged ``"operator"``; a ``ResolveMixin``
        delegates to its own ``iter_references()``; a sequence is searched recursively.
        Anything else yields nothing.
        """
        from airflow.models.baseoperator import AbstractOperator
        from airflow.utils.mixins import ResolveMixin

        if isinstance(obj, AbstractOperator):
            yield obj, "operator"
        elif isinstance(obj, ResolveMixin):
            yield from obj.iter_references()
        elif isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
            # str is a Sequence whose items are themselves one-character strings, so
            # descending into it would recurse forever; text can never hold a reference.
            for o in obj:
                yield from cls._iter_references(o)
class TaskMixin(DependencyMixin):
    """Former name of DependencyMixin; defining a subclass emits a deprecation warning.

    :meta private:
    """

    def __init_subclass__(cls) -> None:
        # Build the message first so the warn() call stays on a single line.
        message = f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}"
        warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
        return super().__init_subclass__()
class DAGNode(DependencyMixin, metaclass=ABCMeta):
    """
    Abstract base for a single node in a workflow graph.

    Per the original description, a node may be an Operator or a Task Group,
    either mapped or unmapped.
    """

    dag: DAG | None = None
    task_group: TaskGroup | None = None
    """The task_group that contains this node"""

    @property
    @abstractmethod
    def node_id(self) -> str:
        """Identifier of this node; concrete subclasses must provide it."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """Return ``node_id`` with the enclosing group's prefix removed, when one applies."""
        group = self.task_group
        if not (group and group.node_id and group.prefix_group_id):
            return self.node_id
        # "task_group_id.task_id" -> "task_id" (the +1 skips the separating dot)
        full_id = self.node_id
        return full_id[len(group.node_id) + 1 :]

    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    def has_dag(self) -> bool:
        """Tell whether a DAG has been attached to this node."""
        return self.dag is not None

    @property
    def dag_id(self) -> str:
        """Return the attached DAG's id, or a fixed placeholder id when there is none."""
        attached = self.dag
        return attached.dag_id if attached else "_in_memory_dag_"

    @property
    def log(self) -> Logger:
        """Logger for this node; not provided by this base class."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Root nodes of this node's sub-graph; concrete subclasses must provide them."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Leaf nodes of this node's sub-graph; concrete subclasses must provide them."""
        raise NotImplementedError()

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Wire this node to one or more other nodes.

        :param task_or_task_list: a single dependency or a sequence of them.
        :param upstream: when true the given nodes become upstream of this one,
            otherwise downstream.
        :param edge_modifier: if truthy, receives ``add_edge_info`` for every edge created.
        :raises AirflowException: if a resolved relative is not an operator, if the
            nodes involved span more than one DAG, or if none of them has a DAG.
        """
        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator

        # Normalise the argument so the rest of the method only deals with sequences.
        candidates = task_or_task_list if isinstance(task_or_task_list, Sequence) else [task_or_task_list]

        operators: list[Operator] = []
        for candidate in candidates:
            candidate.update_relative(self, not upstream, edge_modifier=edge_modifier)
            # An upstream relative attaches through its leaves, a downstream one through its roots.
            endpoints = candidate.leaves if upstream else candidate.roots
            for node in endpoints:
                if isinstance(node, (BaseOperator, MappedOperator)):
                    operators.append(node)
                    continue
                raise AirflowException(
                    f"Relationships can only be set between Operators; received {node.__class__.__name__}"
                )

        # Relationships can only be set if the tasks share a single DAG. Tasks
        # without a DAG are assigned to that DAG.
        dags: set[DAG] = set()
        for node in (*self.roots, *operators):
            if node.has_dag() and node.dag:
                dags.add(node.dag)

        if len(dags) > 1:
            raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
        if not dags:
            raise AirflowException(
                f"Tried to create relationships between tasks that don't have DAGs yet. "
                f"Set the DAG for at least one task and try again: {[self, *operators]}"
            )
        dag = dags.pop()

        if not self.has_dag():
            # This node has no DAG yet: adopt the one shared by the other tasks.
            self.dag = dag

        for node in operators:
            if dag and not node.has_dag():
                # The other task has no DAG yet: register it with the shared DAG.
                dag.add_task(node)
            # The edge always runs source -> target; ``upstream`` only decides which end is self.
            source, target = (node, self) if upstream else (self, node)
            source.downstream_task_ids.add(target.node_id)
            target.upstream_task_ids.add(source.node_id)
            if edge_modifier:
                edge_modifier.add_edge_info(self.dag, source.node_id, target.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make the given node (or nodes) direct downstream relatives of this node."""
        self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make the given node (or nodes) direct upstream relatives of this node."""
        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)

    @property
    def downstream_list(self) -> Iterable[Operator]:
        """Directly downstream nodes, looked up by id in the attached DAG."""
        owner = self.dag
        if not owner:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [owner.get_task(task_id) for task_id in self.downstream_task_ids]

    @property
    def upstream_list(self) -> Iterable[Operator]:
        """Directly upstream nodes, looked up by id in the attached DAG."""
        owner = self.dag
        if not owner:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [owner.get_task(task_id) for task_id in self.upstream_task_ids]

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """Return the ids of the direct relatives on the requested side (downstream by default)."""
        return self.upstream_task_ids if upstream else self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
        """Return the direct relatives on the requested side (downstream by default)."""
        return self.upstream_list if upstream else self.downstream_list

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize a task group's content; used by TaskGroupSerialization."""
        raise NotImplementedError()