1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20from typing import TYPE_CHECKING, cast
21
22from airflow.sdk.exceptions import AirflowException
23
24if TYPE_CHECKING:
25 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
26 from airflow.sdk.definitions._internal.mixins import DependencyMixin
27 from airflow.sdk.definitions.xcom_arg import PlainXComArg
28
29
class BaseSetupTeardownContext:
    """
    Context manager for setup/teardown tasks.

    All state is stored on the class itself (every method is a classmethod or staticmethod).
    Each ``_context_managed_*``, ``_teardown_downstream_of_setup`` and
    ``_setup_upstream_of_teardown`` attribute holds the value for the innermost active context
    (``[]`` meaning "none"), and its ``_previous_*`` counterpart is a stack holding the values
    of the enclosing contexts.

    :meta private:
    """

    # True between ``push_setup_teardown_task`` and ``set_work_task_roots_and_leaves`` (or ``error``).
    active: bool = False
    # Maps a context-managed setup/teardown (a list of tasks is keyed as a tuple) to the
    # non-setup/teardown tasks recorded while it was the current context.
    context_map: dict[AbstractOperator | tuple[AbstractOperator, ...], list[AbstractOperator]] = {}
    _context_managed_setup_task: AbstractOperator | list[AbstractOperator] = []
    _previous_context_managed_setup_task: list[AbstractOperator | list[AbstractOperator]] = []
    _context_managed_teardown_task: AbstractOperator | list[AbstractOperator] = []
    _previous_context_managed_teardown_task: list[AbstractOperator | list[AbstractOperator]] = []
    # Nearest teardown task(s) found around the context-managed setup (see _update_teardown_downstream).
    _teardown_downstream_of_setup: AbstractOperator | list[AbstractOperator] = []
    _previous_teardown_downstream_of_setup: list[AbstractOperator | list[AbstractOperator]] = []
    # Nearest setup task(s) found around the context-managed teardown (see _update_setup_upstream).
    _setup_upstream_of_teardown: AbstractOperator | list[AbstractOperator] = []
    _previous_setup_upstream_of_teardown: list[AbstractOperator | list[AbstractOperator]] = []

    @classmethod
    def push_context_managed_setup_task(cls, task: AbstractOperator | list[AbstractOperator]):
        """Make ``task`` the current context-managed setup, stacking a different current one."""
        setup_task = cls._context_managed_setup_task
        if setup_task and setup_task != task:
            cls._previous_context_managed_setup_task.append(cls._context_managed_setup_task)
        cls._context_managed_setup_task = task

    @classmethod
    def push_context_managed_teardown_task(cls, task: AbstractOperator | list[AbstractOperator]):
        """Make ``task`` the current context-managed teardown, stacking a different current one."""
        teardown_task = cls._context_managed_teardown_task
        if teardown_task and teardown_task != task:
            cls._previous_context_managed_teardown_task.append(cls._context_managed_teardown_task)
        cls._context_managed_teardown_task = task

    @classmethod
    def pop_context_managed_setup_task(cls) -> AbstractOperator | list[AbstractOperator]:
        """
        Leave the current setup context and return its setup task(s).

        If an enclosing setup exists it becomes current again and is set upstream of the
        popped one; otherwise the current setup is reset to ``[]``.
        """
        old_setup_task = cls._context_managed_setup_task
        if cls._previous_context_managed_setup_task:
            cls._context_managed_setup_task = cls._previous_context_managed_setup_task.pop()
            setup_task = cls._context_managed_setup_task
            if setup_task and old_setup_task:
                # enclosing setup >> popped setup
                cls.set_dependency(old_setup_task, setup_task, upstream=False)
        else:
            cls._context_managed_setup_task = []
        return old_setup_task

    @classmethod
    def pop_context_managed_teardown_task(cls) -> AbstractOperator | list[AbstractOperator]:
        """
        Leave the current teardown context and return its teardown task(s).

        If an enclosing teardown exists it becomes current again and is set downstream of the
        popped one; otherwise the current teardown is reset to ``[]``.
        """
        old_teardown_task = cls._context_managed_teardown_task
        if cls._previous_context_managed_teardown_task:
            cls._context_managed_teardown_task = cls._previous_context_managed_teardown_task.pop()
            teardown_task = cls._context_managed_teardown_task
            if teardown_task and old_teardown_task:
                # popped teardown >> enclosing teardown
                cls.set_dependency(old_teardown_task, teardown_task)
        else:
            cls._context_managed_teardown_task = []
        return old_teardown_task

    @classmethod
    def pop_teardown_downstream_of_setup(cls) -> AbstractOperator | list[AbstractOperator]:
        """Restore the enclosing ``_teardown_downstream_of_setup`` and chain it after the popped one."""
        old_teardown_task = cls._teardown_downstream_of_setup
        if cls._previous_teardown_downstream_of_setup:
            cls._teardown_downstream_of_setup = cls._previous_teardown_downstream_of_setup.pop()
            teardown_task = cls._teardown_downstream_of_setup
            if teardown_task and old_teardown_task:
                cls.set_dependency(old_teardown_task, teardown_task)
        else:
            cls._teardown_downstream_of_setup = []
        return old_teardown_task

    @classmethod
    def pop_setup_upstream_of_teardown(cls) -> AbstractOperator | list[AbstractOperator]:
        """Restore the enclosing ``_setup_upstream_of_teardown`` and chain it before the popped one."""
        old_setup_task = cls._setup_upstream_of_teardown
        if cls._previous_setup_upstream_of_teardown:
            cls._setup_upstream_of_teardown = cls._previous_setup_upstream_of_teardown.pop()
            setup_task = cls._setup_upstream_of_teardown
            if setup_task and old_setup_task:
                cls.set_dependency(old_setup_task, setup_task, upstream=False)
        else:
            cls._setup_upstream_of_teardown = []
        return old_setup_task

    @classmethod
    def set_dependency(
        cls,
        receiving_task: AbstractOperator | list[AbstractOperator],
        new_task: AbstractOperator | list[AbstractOperator],
        upstream=True,
    ):
        """
        Wire each task in ``new_task`` to ``receiving_task``.

        :param receiving_task: task(s) passed to ``set_upstream``/``set_downstream``.
        :param new_task: a task, or a list/tuple of tasks, on which the call is made.
        :param upstream: if True, ``receiving_task`` becomes upstream of ``new_task``;
            otherwise it becomes downstream.
        """
        if isinstance(new_task, (list, tuple)):
            for task in new_task:
                cls._set_dependency(task, receiving_task, upstream)
        else:
            cls._set_dependency(new_task, receiving_task, upstream)

    @staticmethod
    def _set_dependency(task, receiving_task, upstream):
        """Call ``task.set_upstream`` or ``task.set_downstream`` with ``receiving_task``."""
        if upstream:
            task.set_upstream(receiving_task)
        else:
            task.set_downstream(receiving_task)

    @classmethod
    def update_context_map(cls, task: DependencyMixin):
        """
        Record ``task`` under the current context-managed setup and teardown.

        Setup and teardown tasks themselves are not recorded. A list of context-managed
        tasks is converted to a tuple so it can be used as a dict key.
        """
        task_ = cast("AbstractOperator", task)
        if task_.is_setup or task_.is_teardown:
            return
        ctx = cls.context_map

        def _append_or_set_item(item):
            if ctx.get(item) is None:
                ctx[item] = [task_]
            else:
                ctx[item].append(task_)

        if setup_task := cls._context_managed_setup_task:
            if isinstance(setup_task, list):
                _append_or_set_item(tuple(setup_task))
            else:
                _append_or_set_item(setup_task)
        if teardown_task := cls._context_managed_teardown_task:
            if isinstance(teardown_task, list):
                _append_or_set_item(tuple(teardown_task))
            else:
                _append_or_set_item(teardown_task)

    @classmethod
    def push_setup_teardown_task(cls, operator: AbstractOperator | list[AbstractOperator]):
        """
        Enter a setup/teardown context for ``operator`` and mark the context active.

        For a list, the first element decides whether it is treated as setups or teardowns.
        """
        if isinstance(operator, list):
            if operator[0].is_teardown:
                cls._push_tasks(operator)
            elif operator[0].is_setup:
                cls._push_tasks(operator, setup=True)
        elif operator.is_teardown:
            cls._push_tasks(operator)
        elif operator.is_setup:
            cls._push_tasks(operator, setup=True)
        cls.active = True

    @classmethod
    def _push_tasks(cls, operator: AbstractOperator | list[AbstractOperator], setup: bool = False):
        """
        Push ``operator`` as the current setup (``setup=True``) or teardown context.

        Also records the nearest teardown(s) related to a pushed setup, or the nearest
        setup(s) related to a pushed teardown.

        :raises ValueError: if a list mixes setup and non-setup tasks.
        """
        if isinstance(operator, list):
            if any(task.is_setup != operator[0].is_setup for task in operator):
                cls.error("All tasks in the list must be either setup or teardown tasks")
        if setup:
            cls.push_context_managed_setup_task(operator)
            # workout the teardown
            cls._update_teardown_downstream(operator)
        else:
            cls.push_context_managed_teardown_task(operator)
            # workout the setups
            cls._update_setup_upstream(operator)

    @staticmethod
    def _find_nearest_tasks(start, origin, predicate) -> list:
        """
        Breadth-first search from ``start`` for the closest tasks satisfying ``predicate``.

        Neighbours are taken from both ``downstream_list`` and ``upstream_list``, and
        ``origin`` is never visited. All matches at the smallest distance are returned, or
        ``[]`` if none are reachable. Each task is visited at most once, so the search ends on
        any graph. A walk that revisits tasks would otherwise loop forever, for example
        bouncing between two work tasks chained after a setup.
        """
        frontier = list(start)
        # Track identity rather than hashing, so no assumption is made about task __hash__.
        seen = {id(origin)}
        seen.update(id(task) for task in frontier)
        while frontier:
            matches = [task for task in frontier if predicate(task)]
            if matches:
                return matches
            next_frontier = []
            for task in frontier:
                for neighbour in task.downstream_list + task.upstream_list:
                    if id(neighbour) not in seen:
                        seen.add(id(neighbour))
                        next_frontier.append(neighbour)
            frontier = next_frontier
        return []

    @classmethod
    def _update_teardown_downstream(cls, operator: AbstractOperator | list[AbstractOperator]):
        """
        Search the graph around the context-managed setup for the nearest teardown task(s).

        The search starts from the setup's downstream tasks and updates
        ``_teardown_downstream_of_setup``, stacking a different previous value.
        """
        # For a list of setups only the first one is used as the search origin.
        operator = operator[0] if isinstance(operator, list) else operator
        teardowns = cls._find_nearest_tasks(
            operator.downstream_list, operator, lambda task: task.is_teardown
        )
        teardown_task = cls._teardown_downstream_of_setup
        if teardown_task and teardown_task != teardowns:
            cls._previous_teardown_downstream_of_setup.append(cls._teardown_downstream_of_setup)
        cls._teardown_downstream_of_setup = teardowns

    @classmethod
    def _update_setup_upstream(cls, operator: AbstractOperator | list[AbstractOperator]):
        """
        Search the graph around the context-managed teardown for the nearest setup task(s).

        The search starts from the teardown's upstream tasks and updates
        ``_setup_upstream_of_teardown``, stacking a different previous value.
        """
        # For a list of teardowns only the first one is used as the search origin.
        operator = operator[0] if isinstance(operator, list) else operator
        setups = cls._find_nearest_tasks(operator.upstream_list, operator, lambda task: task.is_setup)
        setup_task = cls._setup_upstream_of_teardown
        if setup_task and setup_task != setups:
            cls._previous_setup_upstream_of_teardown.append(cls._setup_upstream_of_teardown)
        cls._setup_upstream_of_teardown = setups

    @classmethod
    def set_teardown_task_as_leaves(cls, leaves):
        """Put ``_teardown_downstream_of_setup`` downstream of the context teardown, else of ``leaves``."""
        teardown_task = cls._teardown_downstream_of_setup
        if cls._context_managed_teardown_task:
            cls.set_dependency(cls._context_managed_teardown_task, teardown_task)
        else:
            cls.set_dependency(leaves, teardown_task)

    @classmethod
    def set_setup_task_as_roots(cls, roots):
        """Put ``_setup_upstream_of_teardown`` upstream of the context setup, else of ``roots``."""
        setup_task = cls._setup_upstream_of_teardown
        if cls._context_managed_setup_task:
            cls.set_dependency(cls._context_managed_setup_task, setup_task, upstream=False)
        else:
            cls.set_dependency(roots, setup_task, upstream=False)

    @classmethod
    def set_work_task_roots_and_leaves(cls):
        """
        Wire the work tasks recorded for the current context, then leave the context.

        Work-task roots (no upstream) go downstream of the context-managed setup. Work-task
        leaves (no downstream) go upstream of the context-managed teardown. If no roots or
        leaves exist, the first or last recorded task is used instead. The context is then
        popped and ``active`` is cleared.
        """
        if setup_task := cls._context_managed_setup_task:
            if isinstance(setup_task, list):
                setup_task = tuple(setup_task)
            tasks_in_context = [
                x for x in cls.context_map.get(setup_task, []) if not x.is_teardown and not x.is_setup
            ]
            if tasks_in_context:
                roots = [task for task in tasks_in_context if not task.upstream_list]
                if not roots:
                    setup_task >> tasks_in_context[0]
                else:
                    cls.set_dependency(roots, setup_task, upstream=False)
                leaves = [task for task in tasks_in_context if not task.downstream_list]
                if not leaves:
                    leaves = tasks_in_context[-1]
                cls.set_teardown_task_as_leaves(leaves)

        if teardown_task := cls._context_managed_teardown_task:
            if isinstance(teardown_task, list):
                teardown_task = tuple(teardown_task)
            tasks_in_context = [
                x for x in cls.context_map.get(teardown_task, []) if not x.is_teardown and not x.is_setup
            ]
            if tasks_in_context:
                leaves = [task for task in tasks_in_context if not task.downstream_list]
                if not leaves:
                    teardown_task << tasks_in_context[-1]
                else:
                    cls.set_dependency(leaves, teardown_task)
                roots = [task for task in tasks_in_context if not task.upstream_list]
                if not roots:
                    roots = tasks_in_context[0]
                cls.set_setup_task_as_roots(roots)
        cls.set_setup_teardown_relationships()
        cls.active = False

    @classmethod
    def set_setup_teardown_relationships(cls):
        """
        Pop the current context and chain its setups/teardowns to those of the enclosing one.

        .. code-block:: python

            with setuptask >> teardowntask:
                with setuptask2 >> teardowntask2:
                    ...

        This results in ``setuptask >> setuptask2`` and ``teardowntask2 >> teardowntask``:
        the inner teardown runs before the outer one. The popped entries are also removed
        from ``context_map``.
        """
        setup_task = cls.pop_context_managed_setup_task()
        teardown_task = cls.pop_context_managed_teardown_task()
        # context_map keys lists as tuples (see update_context_map)
        if isinstance(setup_task, list):
            setup_task = tuple(setup_task)
        if isinstance(teardown_task, list):
            teardown_task = tuple(teardown_task)
        cls.pop_teardown_downstream_of_setup()
        cls.pop_setup_upstream_of_teardown()
        cls.context_map.pop(setup_task, None)
        cls.context_map.pop(teardown_task, None)

    @classmethod
    def error(cls, message: str):
        """
        Reset all context state and raise.

        Every current value and every ``_previous_*`` stack is cleared, so nothing stale
        leaks into the next context.

        :raises ValueError: always, with ``message``.
        """
        cls.active = False
        cls.context_map.clear()
        cls._context_managed_setup_task = []
        cls._context_managed_teardown_task = []
        cls._previous_context_managed_setup_task = []
        cls._previous_context_managed_teardown_task = []
        cls._teardown_downstream_of_setup = []
        cls._previous_teardown_downstream_of_setup = []
        cls._setup_upstream_of_teardown = []
        cls._previous_setup_upstream_of_teardown = []
        raise ValueError(message)
328
329
class SetupTeardownContext(BaseSetupTeardownContext):
    """Context manager for setup and teardown tasks."""

    @staticmethod
    def add_task(task: AbstractOperator | PlainXComArg):
        """
        Record ``task`` in the currently active setup/teardown context.

        :param task: an operator, or a ``PlainXComArg`` whose ``operator`` is recorded.
        :raises AirflowException: if no setup/teardown context is active.
        """
        # Imported lazily, as in the original module, rather than at module import time.
        from airflow.sdk.definitions.xcom_arg import PlainXComArg

        context = SetupTeardownContext
        if not context.active:
            raise AirflowException("Cannot add task to context outside the context manager.")
        operator = task.operator if isinstance(task, PlainXComArg) else task
        context.update_context_map(operator)