1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20from typing import TYPE_CHECKING, cast
21
22from airflow.exceptions import AirflowException
23
24if TYPE_CHECKING:
25 from airflow.models.abstractoperator import AbstractOperator
26 from airflow.models.taskmixin import DependencyMixin
27 from airflow.models.xcom_arg import PlainXComArg
28
29
class BaseSetupTeardownContext:
    """Context manager for setup/teardown tasks.

    All state is stored in class attributes and mutated through classmethods, so nested
    ``with`` blocks share one set of "current" values plus ``_previous_*`` stacks that hold
    the values of the enclosing contexts.

    :meta private:
    """

    # True while a setup/teardown context is open (set in push_setup_teardown_task,
    # cleared in set_work_task_roots_and_leaves and error).
    active: bool = False
    # Maps a context-managed setup/teardown task (or a tuple of them) to the work tasks
    # registered while that task was the current context (see update_context_map).
    context_map: dict[AbstractOperator | tuple[AbstractOperator], list[AbstractOperator]] = {}
    # Setup task(s) of the innermost context, and the stack of enclosing ones.
    _context_managed_setup_task: AbstractOperator | list[AbstractOperator] = []
    _previous_context_managed_setup_task: list[AbstractOperator | list[AbstractOperator]] = []
    # Teardown task(s) of the innermost context, and the stack of enclosing ones.
    _context_managed_teardown_task: AbstractOperator | list[AbstractOperator] = []
    _previous_context_managed_teardown_task: list[AbstractOperator | list[AbstractOperator]] = []
    # Teardowns discovered around the current setup (see _update_teardown_downstream), and stack.
    _teardown_downstream_of_setup: AbstractOperator | list[AbstractOperator] = []
    _previous_teardown_downstream_of_setup: list[AbstractOperator | list[AbstractOperator]] = []
    # Setups discovered around the current teardown (see _update_setup_upstream), and stack.
    _setup_upstream_of_teardown: AbstractOperator | list[AbstractOperator] = []
    _previous_setup_upstream_of_teardown: list[AbstractOperator | list[AbstractOperator]] = []

    @classmethod
    def push_context_managed_setup_task(cls, task: AbstractOperator | list[AbstractOperator]):
        """Make ``task`` the current setup, stacking the previous one if it is set and different."""
        setup_task = cls._context_managed_setup_task
        if setup_task and setup_task != task:
            cls._previous_context_managed_setup_task.append(cls._context_managed_setup_task)
        cls._context_managed_setup_task = task

    @classmethod
    def push_context_managed_teardown_task(cls, task: AbstractOperator | list[AbstractOperator]):
        """Make ``task`` the current teardown, stacking the previous one if it is set and different."""
        teardown_task = cls._context_managed_teardown_task
        if teardown_task and teardown_task != task:
            cls._previous_context_managed_teardown_task.append(cls._context_managed_teardown_task)
        cls._context_managed_teardown_task = task

    @classmethod
    def pop_context_managed_setup_task(cls) -> AbstractOperator | list[AbstractOperator]:
        """
        Restore the enclosing setup task and return the one being left.

        When an enclosing setup exists, it is wired so that ``enclosing >> popped``.
        """
        old_setup_task = cls._context_managed_setup_task
        if cls._previous_context_managed_setup_task:
            cls._context_managed_setup_task = cls._previous_context_managed_setup_task.pop()
            setup_task = cls._context_managed_setup_task
            if setup_task and old_setup_task:
                cls.set_dependency(old_setup_task, setup_task, upstream=False)
        else:
            cls._context_managed_setup_task = []
        return old_setup_task

    @classmethod
    def pop_context_managed_teardown_task(cls) -> AbstractOperator | list[AbstractOperator]:
        """
        Restore the enclosing teardown task and return the one being left.

        When an enclosing teardown exists, it is wired so that ``popped >> enclosing``.
        """
        old_teardown_task = cls._context_managed_teardown_task
        if cls._previous_context_managed_teardown_task:
            cls._context_managed_teardown_task = cls._previous_context_managed_teardown_task.pop()
            teardown_task = cls._context_managed_teardown_task
            if teardown_task and old_teardown_task:
                cls.set_dependency(old_teardown_task, teardown_task)
        else:
            cls._context_managed_teardown_task = []
        return old_teardown_task

    @classmethod
    def pop_teardown_downstream_of_setup(cls) -> AbstractOperator | list[AbstractOperator]:
        """Restore the previous discovered-teardown value, wiring ``popped >> restored``."""
        old_teardown_task = cls._teardown_downstream_of_setup
        if cls._previous_teardown_downstream_of_setup:
            cls._teardown_downstream_of_setup = cls._previous_teardown_downstream_of_setup.pop()
            teardown_task = cls._teardown_downstream_of_setup
            if teardown_task and old_teardown_task:
                cls.set_dependency(old_teardown_task, teardown_task)
        else:
            cls._teardown_downstream_of_setup = []
        return old_teardown_task

    @classmethod
    def pop_setup_upstream_of_teardown(cls) -> AbstractOperator | list[AbstractOperator]:
        """Restore the previous discovered-setup value, wiring ``restored >> popped``."""
        old_setup_task = cls._setup_upstream_of_teardown
        if cls._previous_setup_upstream_of_teardown:
            cls._setup_upstream_of_teardown = cls._previous_setup_upstream_of_teardown.pop()
            setup_task = cls._setup_upstream_of_teardown
            if setup_task and old_setup_task:
                cls.set_dependency(old_setup_task, setup_task, upstream=False)
        else:
            cls._setup_upstream_of_teardown = []
        return old_setup_task

    @classmethod
    def set_dependency(
        cls,
        receiving_task: AbstractOperator | list[AbstractOperator],
        new_task: AbstractOperator | list[AbstractOperator],
        upstream=True,
    ):
        """
        Relate every task in ``new_task`` (a single task, list or tuple) to ``receiving_task``.

        :param receiving_task: task(s) passed to ``set_upstream``/``set_downstream``.
        :param new_task: task(s) on which the method is called.
        :param upstream: call ``task.set_upstream(receiving_task)`` when True,
            otherwise ``task.set_downstream(receiving_task)``.
        """
        if isinstance(new_task, (list, tuple)):
            for task in new_task:
                cls._set_dependency(task, receiving_task, upstream)
        else:
            cls._set_dependency(new_task, receiving_task, upstream)

    @staticmethod
    def _set_dependency(task, receiving_task, upstream):
        """Call ``set_upstream`` or ``set_downstream`` on a single ``task``."""
        if upstream:
            task.set_upstream(receiving_task)
        else:
            task.set_downstream(receiving_task)

    @classmethod
    def update_context_map(cls, task: DependencyMixin):
        """
        Register a work task under the current context-managed setup and teardown keys.

        Setup and teardown tasks are ignored; list-valued context tasks are keyed as tuples.
        """
        from airflow.models.abstractoperator import AbstractOperator

        task_ = cast(AbstractOperator, task)
        if task_.is_setup or task_.is_teardown:
            return
        ctx = cls.context_map

        def _append_or_set_item(item):
            if ctx.get(item) is None:
                ctx[item] = [task_]
            else:
                ctx[item].append(task_)

        if setup_task := cls._context_managed_setup_task:
            if isinstance(setup_task, list):
                _append_or_set_item(tuple(setup_task))
            else:
                _append_or_set_item(setup_task)
        if teardown_task := cls._context_managed_teardown_task:
            if isinstance(teardown_task, list):
                _append_or_set_item(tuple(teardown_task))
            else:
                _append_or_set_item(teardown_task)

    @classmethod
    def push_setup_teardown_task(cls, operator: AbstractOperator | list[AbstractOperator]):
        """
        Enter a setup/teardown context for ``operator`` and mark the context active.

        For a list, the kind (setup or teardown) is taken from its first element.
        """
        if isinstance(operator, list):
            if operator[0].is_teardown:
                cls._push_tasks(operator)
            elif operator[0].is_setup:
                cls._push_tasks(operator, setup=True)
        elif operator.is_teardown:
            cls._push_tasks(operator)
        elif operator.is_setup:
            cls._push_tasks(operator, setup=True)
        cls.active = True

    @classmethod
    def _push_tasks(cls, operator: AbstractOperator | list[AbstractOperator], setup: bool = False):
        """
        Push ``operator`` as the current setup (``setup=True``) or teardown.

        :raises ValueError: via ``error`` if a list mixes setup and non-setup tasks.
        """
        if isinstance(operator, list):
            if any(task.is_setup != operator[0].is_setup for task in operator):
                cls.error("All tasks in the list must be either setup or teardown tasks")
        if setup:
            cls.push_context_managed_setup_task(operator)
            # workout the teardown
            cls._update_teardown_downstream(operator)
        else:
            cls.push_context_managed_teardown_task(operator)
            # workout the setups
            cls._update_setup_upstream(operator)

    @staticmethod
    def _find_nearest_tasks(origin, start_tasks, predicate) -> list:
        """
        Breadth-first search from ``start_tasks`` for the nearest tasks matching ``predicate``.

        The search expands through both ``downstream_list`` and ``upstream_list`` and never
        re-enters ``origin``. Tasks are tracked by identity, so each task is visited at most
        once. This guarantees termination on any graph and avoids the exponential frontier
        growth of a walk without a visited set.

        :param origin: task the search starts from; it is excluded from the search.
        :param start_tasks: first level of the search.
        :param predicate: callable selecting the tasks to return.
        :return: matching tasks of the first level containing any (in discovery order,
            de-duplicated), or an empty list if none are reachable.
        """
        visited = {id(origin)}
        frontier = []
        for task in start_tasks:
            if id(task) not in visited:
                visited.add(id(task))
                frontier.append(task)
        while frontier:
            found = [task for task in frontier if predicate(task)]
            if found:
                return found
            next_frontier = []
            for task in frontier:
                for relative in task.downstream_list + task.upstream_list:
                    if id(relative) not in visited:
                        visited.add(id(relative))
                        next_frontier.append(relative)
            frontier = next_frontier
        return []

    @classmethod
    def _update_teardown_downstream(cls, operator: AbstractOperator | list[AbstractOperator]):
        """
        Find the nearest teardown tasks around the setup in the context manager.

        Only the first task is used when ``operator`` is a list. The search starts at its
        downstream tasks (see ``_find_nearest_tasks``). The result, possibly empty, becomes
        ``_teardown_downstream_of_setup``; a differing previous value is stacked.
        """
        operator = operator[0] if isinstance(operator, list) else operator
        teardowns = cls._find_nearest_tasks(
            operator, operator.downstream_list, lambda task: task.is_teardown
        )
        teardown_task = cls._teardown_downstream_of_setup
        if teardown_task and teardown_task != teardowns:
            cls._previous_teardown_downstream_of_setup.append(cls._teardown_downstream_of_setup)
        cls._teardown_downstream_of_setup = teardowns

    @classmethod
    def _update_setup_upstream(cls, operator: AbstractOperator | list[AbstractOperator]):
        """
        Find the nearest setup tasks around the teardown in the context manager.

        Only the first task is used when ``operator`` is a list. The search starts at its
        upstream tasks (see ``_find_nearest_tasks``). The result, possibly empty, becomes
        ``_setup_upstream_of_teardown``; a differing previous value is stacked.
        """
        operator = operator[0] if isinstance(operator, list) else operator
        setups = cls._find_nearest_tasks(operator, operator.upstream_list, lambda task: task.is_setup)
        setup_task = cls._setup_upstream_of_teardown
        if setup_task and setup_task != setups:
            cls._previous_setup_upstream_of_teardown.append(cls._setup_upstream_of_teardown)
        cls._setup_upstream_of_teardown = setups

    @classmethod
    def set_teardown_task_as_leaves(cls, leaves):
        """
        Make the discovered teardowns run after the context's work.

        The context-managed teardown is used as upstream when one is set, otherwise ``leaves``.
        """
        teardown_task = cls._teardown_downstream_of_setup
        if cls._context_managed_teardown_task:
            cls.set_dependency(cls._context_managed_teardown_task, teardown_task)
        else:
            cls.set_dependency(leaves, teardown_task)

    @classmethod
    def set_setup_task_as_roots(cls, roots):
        """
        Make the discovered setups run before the context's work.

        The context-managed setup is used as downstream when one is set, otherwise ``roots``.
        """
        setup_task = cls._setup_upstream_of_teardown
        if cls._context_managed_setup_task:
            cls.set_dependency(cls._context_managed_setup_task, setup_task, upstream=False)
        else:
            cls.set_dependency(roots, setup_task, upstream=False)

    @classmethod
    def set_work_task_roots_and_leaves(cls):
        """
        Wire the work tasks of the context being exited and leave it.

        The current setup goes upstream of the work tasks without upstreams (or of the first
        work task if every task has one). The current teardown goes downstream of the work
        tasks without downstreams (or of the last work task). Then the context stacks are
        popped and the context is marked inactive.
        """
        if setup_task := cls._context_managed_setup_task:
            if isinstance(setup_task, list):
                setup_task = tuple(setup_task)
            tasks_in_context = [
                x for x in cls.context_map.get(setup_task, []) if not x.is_teardown and not x.is_setup
            ]
            if tasks_in_context:
                roots = [task for task in tasks_in_context if not task.upstream_list]
                if not roots:
                    setup_task >> tasks_in_context[0]
                else:
                    cls.set_dependency(roots, setup_task, upstream=False)
                leaves = [task for task in tasks_in_context if not task.downstream_list]
                if not leaves:
                    leaves = tasks_in_context[-1]
                cls.set_teardown_task_as_leaves(leaves)

        if teardown_task := cls._context_managed_teardown_task:
            if isinstance(teardown_task, list):
                teardown_task = tuple(teardown_task)
            tasks_in_context = [
                x for x in cls.context_map.get(teardown_task, []) if not x.is_teardown and not x.is_setup
            ]
            if tasks_in_context:
                leaves = [task for task in tasks_in_context if not task.downstream_list]
                if not leaves:
                    teardown_task << tasks_in_context[-1]
                else:
                    cls.set_dependency(leaves, teardown_task)
                roots = [task for task in tasks_in_context if not task.upstream_list]
                if not roots:
                    roots = tasks_in_context[0]
                cls.set_setup_task_as_roots(roots)
        cls.set_setup_teardown_relationships()
        cls.active = False

    @classmethod
    def set_setup_teardown_relationships(cls):
        """
        Set relationship between setup to setup and teardown to teardown.

        .. code-block:: python

            with setuptask >> teardowntask:
                with setuptask2 >> teardowntask2:
                    ...

        We set ``setuptask >> setuptask2`` and ``teardowntask2 >> teardowntask``, so inner
        teardowns run before outer ones. The exited context's entries are also removed from
        ``context_map``.
        """
        setup_task = cls.pop_context_managed_setup_task()
        teardown_task = cls.pop_context_managed_teardown_task()
        if isinstance(setup_task, list):
            setup_task = tuple(setup_task)
        if isinstance(teardown_task, list):
            teardown_task = tuple(teardown_task)
        cls.pop_teardown_downstream_of_setup()
        cls.pop_setup_upstream_of_teardown()
        cls.context_map.pop(setup_task, None)
        cls.context_map.pop(teardown_task, None)

    @classmethod
    def error(cls, message: str):
        """
        Reset all context-manager state and raise.

        :raises ValueError: always, with ``message``.
        """
        cls.active = False
        cls.context_map.clear()
        cls._context_managed_setup_task = []
        cls._context_managed_teardown_task = []
        cls._previous_context_managed_setup_task = []
        cls._previous_context_managed_teardown_task = []
        # Reset the discovered setup/teardown state as well; otherwise values from the failed
        # context could be stacked and later popped into an unrelated context manager.
        cls._teardown_downstream_of_setup = []
        cls._previous_teardown_downstream_of_setup = []
        cls._setup_upstream_of_teardown = []
        cls._previous_setup_upstream_of_teardown = []
        raise ValueError(message)
329
330
class SetupTeardownContext(BaseSetupTeardownContext):
    """Context manager for setup and teardown tasks."""

    @staticmethod
    def add_task(task: AbstractOperator | PlainXComArg):
        """
        Register ``task`` with the currently open setup/teardown context.

        :param task: an operator, or a ``PlainXComArg`` whose ``operator`` gets registered.
        :raises AirflowException: when no setup/teardown context manager is active.
        """
        from airflow.models.xcom_arg import PlainXComArg

        if not SetupTeardownContext.active:
            raise AirflowException("Cannot add task to context outside the context manager.")
        # XComArgs wrap the operator that produces them; the context map tracks operators.
        operator = task.operator if isinstance(task, PlainXComArg) else task
        SetupTeardownContext.update_context_map(operator)