1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20from abc import abstractmethod
21from collections.abc import Iterable, Sequence
22from typing import TYPE_CHECKING, Any
23
24if TYPE_CHECKING:
25 from typing import TypeAlias
26
27 from airflow.sdk.bases.operator import BaseOperator
28 from airflow.sdk.definitions.context import Context
29 from airflow.sdk.definitions.edges import EdgeModifier
30 from airflow.sdk.definitions.mappedoperator import MappedOperator
31
32 Operator: TypeAlias = BaseOperator | MappedOperator
33
34# TODO: Should this all just live on DAGNode?
35
36
class DependencyMixin:
    """
    Mixin implementing common dependency-setting methods such as ``>>`` and ``<<``.

    Subclasses provide :meth:`set_upstream` and :meth:`set_downstream`; the bit-shift
    operators defined here are thin wrappers that delegate to those two methods.
    """

    @property
    def roots(self) -> Iterable[DependencyMixin]:
        """
        List of root nodes -- ones with no upstream dependencies.

        a.k.a. the "start" of this sub-graph
        """
        raise NotImplementedError()

    @property
    def leaves(self) -> Iterable[DependencyMixin]:
        """
        List of leaf nodes -- ones with no downstream dependencies.

        a.k.a. the "end" of this sub-graph
        """
        raise NotImplementedError()

    @abstractmethod
    def set_upstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """
        Set a task or a task list to be directly upstream from the current task.

        :param other: A single node or a sequence of nodes to place upstream.
        :param edge_modifier: Optional ``EdgeModifier`` to associate with the relationship.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_downstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """
        Set a task or a task list to be directly downstream from the current task.

        :param other: A single node or a sequence of nodes to place downstream.
        :param edge_modifier: Optional ``EdgeModifier`` to associate with the relationship.
        """
        raise NotImplementedError()

    def as_setup(self) -> DependencyMixin:
        """Mark a task as setup task."""
        raise NotImplementedError()

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | None = None,
        on_failure_fail_dagrun: bool | None = None,
    ) -> DependencyMixin:
        """Mark a task as teardown and set its setups as direct relatives."""
        raise NotImplementedError()

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Update relationship information about another DependencyMixin. Default is no-op.

        Override if necessary.
        """

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``self << other``: set *other* upstream of ``self`` and return *other*."""
        self.set_upstream(other)
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``self >> other``: set *other* downstream of ``self`` and return *other*."""
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``other >> self`` when *other* has no ``__rshift__`` (e.g. a list).

        This makes ``[task1, task2] >> task`` work: *other* is set upstream of ``self``
        and ``self`` is returned so that chaining can continue.
        """
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``other << self`` when *other* has no ``__lshift__`` (e.g. a list).

        This makes ``[task1, task2] << task`` work: *other* is set downstream of ``self``
        and ``self`` is returned so that chaining can continue.
        """
        self.__rshift__(other)
        return self

    @classmethod
    def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]:
        """
        Recursively yield ``(reference, kind)`` pairs found in *obj*.

        An operator yields ``(obj, "operator")``. A :class:`ResolveMixin` delegates to its
        own ``iter_references``. Other sequences are searched element by element.
        Anything else yields nothing.
        """
        # str/bytes are Sequences but can never hold references. Without this guard a str
        # recurses forever, because each one-character element is itself a str.
        if isinstance(obj, (str, bytes)):
            return

        # Deferred import (kept local as in the original), presumably to avoid an import cycle.
        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator

        if isinstance(obj, AbstractOperator):
            yield obj, "operator"
        elif isinstance(obj, ResolveMixin):
            yield from obj.iter_references()
        elif isinstance(obj, Sequence):
            for o in obj:
                yield from cls._iter_references(o)
125
126
class ResolveMixin:
    """
    Base for values that are only resolved at runtime.

    Subclasses are expected to override both methods below; the base versions
    simply raise :class:`NotImplementedError`.
    """

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """
        Return the underlying XCom references contained in this value.

        The Dag parser relies on this to discover task dependencies recursively.

        :meta private:
        """
        raise NotImplementedError()

    def resolve(self, context: Context) -> Any:
        """
        Compute the concrete value of this object at runtime.

        :meta private:
        """
        raise NotImplementedError()