1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20from abc import abstractmethod
21from collections.abc import Iterable, Sequence
22from typing import TYPE_CHECKING, Any
23
24if TYPE_CHECKING:
25 from typing import TypeAlias
26
27 from airflow.sdk.bases.operator import BaseOperator
28 from airflow.sdk.definitions.context import Context
29 from airflow.sdk.definitions.edges import EdgeModifier
30 from airflow.sdk.definitions.mappedoperator import MappedOperator
31
32 Operator: TypeAlias = BaseOperator | MappedOperator
33
34
class DependencyMixin:
    """
    Mixin implementing common dependency-setting methods like ``>>`` and ``<<``.

    Subclasses supply :meth:`set_upstream` and :meth:`set_downstream`; the
    bit-shift operators defined here (including the reflected variants used
    when the left operand is e.g. a plain ``list``) delegate to those two methods.
    """

    @property
    def roots(self) -> Iterable[DependencyMixin]:
        """
        List of root nodes -- ones with no upstream dependencies.

        a.k.a. the "start" of this sub-graph
        """
        raise NotImplementedError()

    @property
    def leaves(self) -> Iterable[DependencyMixin]:
        """
        List of leaf nodes -- ones with no downstream dependencies.

        a.k.a. the "end" of this sub-graph
        """
        raise NotImplementedError()

    @abstractmethod
    def set_upstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """Set a task or a task list to be directly upstream from the current task."""
        raise NotImplementedError()

    @abstractmethod
    def set_downstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """Set a task or a task list to be directly downstream from the current task."""
        raise NotImplementedError()

    def as_setup(self) -> DependencyMixin:
        """Mark a task as setup task."""
        raise NotImplementedError()

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | None = None,
        on_failure_fail_dagrun: bool | None = None,
    ) -> DependencyMixin:
        """Mark a task as teardown and set its setups as direct relatives."""
        raise NotImplementedError()

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Update relationship information about another DependencyMixin. Default is no-op.

        Override if necessary.

        :param other: The relative whose relationship information is being updated.
        :param upstream: Whether *other* is being set upstream (``True``) or
            downstream (``False``) of this object.
        :param edge_modifier: Optional edge modifier associated with the relationship.
        """

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``self << other``: make *other* directly upstream of *self*.

        Returns *other* so that chains such as ``a << b << c`` continue from *other*.
        """
        self.set_upstream(other)
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``self >> other``: make *other* directly downstream of *self*.

        Returns *other* so that chains such as ``a >> b >> c`` continue from *other*.
        """
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``[Task] >> Task``, because lists don't have a ``__rshift__`` operator.

        Python calls this reflected method when the left operand (e.g. a list)
        does not implement ``>>``. *other* becomes upstream of *self*, and
        *self* is returned so the chain continues from this task.
        """
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``[Task] << Task``, because lists don't have a ``__lshift__`` operator.

        Python calls this reflected method when the left operand (e.g. a list)
        does not implement ``<<``. *other* becomes downstream of *self*, and
        *self* is returned so the chain continues from this task.
        """
        self.__rshift__(other)
        return self

    @classmethod
    def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]:
        """
        Recursively yield ``(reference, kind)`` pairs found in *obj*.

        Operators yield themselves with kind ``"operator"``; :class:`ResolveMixin`
        values delegate to their own ``iter_references()``; other sequences are
        searched element by element. Anything else yields nothing.
        """
        # ``str`` is a Sequence whose one-character elements are themselves
        # strings containing only themselves, so recursing into text never
        # terminates. Text/bytes can never hold operator references, so skip
        # them up front (this also avoids the import below for such values).
        if isinstance(obj, (str, bytes)):
            return

        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator

        if isinstance(obj, AbstractOperator):
            yield obj, "operator"
        elif isinstance(obj, ResolveMixin):
            yield from obj.iter_references()
        elif isinstance(obj, Sequence):
            for o in obj:
                yield from cls._iter_references(o)
123
124
class ResolveMixin:
    """
    A value that is only resolved at runtime.

    Concrete subclasses must override both hooks below; this base class only
    raises :class:`NotImplementedError` from each of them.
    """

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """
        Yield the underlying XCom references contained in this value.

        The Dag parser calls this recursively to discover task dependencies.

        :meta private:
        """
        raise NotImplementedError()

    def resolve(self, context: Context) -> Any:
        """
        Produce the concrete runtime value for *context*.

        :meta private:
        """
        raise NotImplementedError()