1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""A collection of closely related tasks on the same Dag that should be grouped together visually."""
19
20from __future__ import annotations
21
22import copy
23import re
24import weakref
25from collections.abc import Generator, Iterator, Sequence
26from typing import TYPE_CHECKING, Any
27
28import attrs
29
30from airflow.sdk import TriggerRule
31from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key
32from airflow.sdk.exceptions import (
33 AirflowDagCycleException,
34 DuplicateTaskIdFound,
35 TaskAlreadyInTaskGroup,
36)
37
38if TYPE_CHECKING:
39 from airflow.sdk.bases.operator import BaseOperator
40 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
41 from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput
42 from airflow.sdk.definitions._internal.mixins import DependencyMixin
43 from airflow.sdk.definitions.dag import DAG
44 from airflow.sdk.definitions.edges import EdgeModifier
45 from airflow.sdk.types import Operator
46 from airflow.serialization.enums import DagAttributeTypes
47
48
def _default_parent_group() -> TaskGroup | None:
    """Return whatever ``TaskGroupContext.get_current()`` reports as the current TaskGroup."""
    # Imported lazily, inside the function, as in the rest of this module.
    from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

    current_group = TaskGroupContext.get_current()
    return current_group
53
54
def _parent_used_group_ids(tg: TaskGroup) -> set:
    """
    Return the ``used_group_ids`` set of *tg*'s parent group.

    The parent's set object itself is returned (not a copy), so parent and child
    share one registry. When *tg* has no parent, a fresh empty set is returned.
    """
    parent = tg.parent_group
    return parent.used_group_ids if parent else set()
59
60
# This could be achieved with `@dag.default` and make this a method, but for some unknown reason when we do
# that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely lose track that this is an Attrs class. So
# we've gone with this and moved on with our lives, mypy is too much of a dark beast to battle over this.
def _default_dag(instance: TaskGroup):
    """
    Compute the default ``dag`` for a TaskGroup under construction.

    The parent group's ``dag`` is used when *instance* has a parent group;
    otherwise the result of ``DagContext.get_current()`` is returned.
    """
    from airflow.sdk.definitions._internal.contextmanager import DagContext

    parent = instance.parent_group
    if parent is None:
        return DagContext.get_current()
    return parent.dag
70
71
72# Mypy does not like a lambda for some reason. An explicit annotated function makes it happy.
def _validate_group_id(instance, attribute, value: str) -> None:
    """
    Attrs validator for ``group_id``; delegates to ``validate_group_key``.

    *instance* and *attribute* are accepted only to satisfy the attrs validator
    signature and are not used.
    """
    validate_group_key(value)
75
76
@attrs.define(repr=False)
class TaskGroup(DAGNode):
    """
    A collection of tasks.

    When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
    all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the Dag. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The Dag that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the Dag level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    :param group_display_name: If set, this will be the display name for the TaskGroup node in the UI.
    """

    _group_id: str | None = attrs.field(
        validator=attrs.validators.optional(_validate_group_id),
        # This is the default behaviour for attrs, but by specifying this it makes IDEs happier
        alias="group_id",
    )
    group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    prefix_group_id: bool = attrs.field(default=True)
    parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
    dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
    # Deep-copied on assignment so later mutation of the caller's dict does not leak in.
    default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
    tooltip: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    # Direct children (tasks and nested groups), keyed by their node_id.
    children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)

    upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    downstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    # The parent's set object is reused when there is a parent, so ids are tracked in one
    # shared registry; the attribute itself cannot be rebound after construction.
    used_group_ids: set[str] = attrs.field(
        default=attrs.Factory(_parent_used_group_ids, takes_self=True),
        init=False,
        on_setattr=attrs.setters.frozen,
    )

    ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str))
    ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str))

    add_suffix_on_collision: bool = False

    @dag.validator
    def _validate_dag(self, _attr, dag):
        """Reject construction when no (truthy) dag could be resolved."""
        if not dag:
            raise RuntimeError("TaskGroup can only be used inside a dag")

    def __attrs_post_init__(self):
        """Resolve id collisions, attach to the parent group and register the ids in use."""
        # TODO: If attrs supported init only args we could use that here
        # https://github.com/python-attrs/attrs/issues/342
        self._check_for_group_id_collisions(self.add_suffix_on_collision)

        if self._group_id and not self.parent_group and self.dag:
            # Support `tg = TaskGroup(x, dag=dag)`
            self.parent_group = self.dag.task_group

        if self.parent_group:
            self.parent_group.add(self)
            if self.parent_group.default_args:
                # Values set on this group win over those inherited from the parent.
                self.default_args = {
                    **self.parent_group.default_args,
                    **self.default_args,
                }

        if self._group_id:
            self.used_group_ids.add(self.group_id)
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Ensure this group's id does not clash with an id already in ``used_group_ids``.

        :param add_suffix_on_collision: when True, a clash is resolved by rewriting
            ``_group_id`` with a ``__<n>`` suffix; when False, a clash raises.
        :raises DuplicateTaskIdFound: if the id is already used and
            *add_suffix_on_collision* is False.
        """
        if self._group_id is None:
            return
        qualified_id = self.group_id
        if qualified_id is None or qualified_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        # If the given group_id is already used, assign a suffix by incrementing the largest used one.
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        base = re.split(r"__\d+$", self._group_id)[0]
        # ``used_group_ids`` stores fully qualified ids (see ``group_id``), so existing suffixes
        # must be looked up under the qualified base; otherwise groups nested in a prefixed
        # parent never see their earlier ``__<n>`` siblings. The base is escaped because ids
        # may contain regex metacharacters such as ".".
        qualified_base = re.split(r"__\d+$", qualified_id)[0]
        suffix_pattern = re.compile(rf"^{re.escape(qualified_base)}__(\d+)$")
        suffixes = [
            int(match.group(1))
            for used_group_id in self.used_group_ids
            if used_group_id is not None and (match := suffix_pattern.match(used_group_id))
        ]
        if not suffixes:
            self._group_id += "__1"
        else:
            self._group_id = f"{base}__{max(suffixes) + 1}"

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag, parent_group=None)

    @property
    def node_id(self):
        """Identifier of this node; for a TaskGroup this is its ``group_id``."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self._group_id

    @property
    def task_group(self) -> TaskGroup | None:
        """Alias of ``parent_group``."""
        return self.parent_group

    @task_group.setter
    def task_group(self, value: TaskGroup | None):
        self.parent_group = value

    def __iter__(self):
        """Iterate over the contained tasks, descending into nested TaskGroups."""
        for child in self.children.values():
            yield from self._iter_child(child)

    @staticmethod
    def _iter_child(child):
        """Yield *child* itself, or everything it iterates over when it is a TaskGroup."""
        if isinstance(child, TaskGroup):
            yield from child
        else:
            yield child

    def add(self, task: DAGNode) -> DAGNode:
        """
        Add a task or TaskGroup to this TaskGroup.

        :param task: the node to register as a direct child.
        :return: the node that was added.
        :raises TaskAlreadyInTaskGroup: if an operator already belongs to another group.
        :raises DuplicateTaskIdFound: if a child with the same node id already exists.
        :raises ValueError: if *task* is a TaskGroup bound to a different Dag, or a
            TaskGroup that already has children.

        :meta private:
        """
        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        if TaskGroupContext.active:
            if task.task_group and task.task_group != self:
                # Detach from the previous group before claiming the node here.
                task.task_group.children.pop(task.node_id, None)
                task.task_group = self
        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the Dag")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # Exceptions do not apply %-style formatting to their arguments,
                    # so the message has to be built explicitly.
                    raise ValueError(
                        f"Cannot mix TaskGroups from different Dags: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise ValueError("Cannot add a non-empty TaskGroup")

        self.children[key] = task
        return task

    def _remove(self, task: DAGNode) -> None:
        """
        Remove a direct child and release its id from ``used_group_ids``.

        :raises KeyError: if *task* is not a direct child of this group, or its id
            is not present in ``used_group_ids``.
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if (
            self._group_id
            and self.parent_group
            and self.parent_group.prefix_group_id
            and self.parent_group._group_id
        ):
            # defer to parent whether it adds a prefix
            return self.parent_group.child_id(self._group_id)
        return self._group_id

    @property
    def label(self) -> str | None:
        """Node label used in the UI: ``group_display_name`` if set, else the unprefixed group_id."""
        return self.group_display_name or self._group_id

    def update_relative(
        self,
        other: DependencyMixin,
        upstream: bool = True,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Override TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.

        :param other: the TaskGroup or task-like object to relate to.
        :param upstream: True if *other* is upstream of this group, False if downstream.
        :param edge_modifier: optional modifier whose edge info is recorded on the join nodes.
        :raises RuntimeError: if a root of *other* is not a DAGNode.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise RuntimeError(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        # Helper function to find leaves from a task list or task group
        def find_leaves(group_or_task) -> list[Any]:
            while group_or_task:
                group_or_task_leaves = list(group_or_task.get_leaves())
                if group_or_task_leaves:
                    return group_or_task_leaves
                # No leaves of its own: fall back to the tasks recorded as upstream of it,
                # and failing that, retry with the enclosing group.
                if group_or_task.upstream_task_ids:
                    upstream_task_ids_list = list(group_or_task.upstream_task_ids)
                    return [self.dag.get_task(task_id) for task_id in upstream_task_ids_list]
                group_or_task = group_or_task.parent_group
            return []

        # Computed before update_relative() below mutates the recorded relationships.
        leaves = find_leaves(self)

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in leaves:  # Use the fetched leaves
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Push this group onto ``TaskGroupContext`` and return it."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.push(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Pop the top entry from ``TaskGroupContext``."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.pop()

    def has_task(self, task: BaseOperator) -> bool:
        """Return True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no upstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}
        for task in tasks:
            if task.upstream_task_ids.isdisjoint(ids):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Return a generator of tasks with no downstream dependencies within the TaskGroup.

        A teardown task with no in-group downstream is not yielded itself; its in-group
        upstream tasks are searched instead for the nearest non-teardown tasks.
        """
        tasks = list(self)
        ids = {x.task_id for x in tasks}

        def has_non_teardown_downstream(task, exclude: str):
            # True if *task* has an in-group, non-teardown downstream other than *exclude*.
            for down_task in task.downstream_list:
                if down_task.task_id == exclude:
                    continue
                if down_task.task_id not in ids:
                    continue
                if not down_task.is_teardown:
                    return True
            return False

        def recurse_for_first_non_teardown(task):
            for upstream_task in task.upstream_list:
                if upstream_task.task_id not in ids:
                    # upstream task is not in task group
                    continue
                elif upstream_task.is_teardown:
                    yield from recurse_for_first_non_teardown(upstream_task)
                elif task.is_teardown and upstream_task.is_setup:
                    # don't go through the teardown-to-setup path
                    continue
                # return unless upstream task already has non-teardown downstream in group
                elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
                    yield upstream_task

        for task in tasks:
            if task.downstream_task_ids.isdisjoint(ids):
                if not task.is_teardown:
                    yield task
                else:
                    yield from recurse_for_first_non_teardown(task)

    def child_id(self, label):
        """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        Creates a unique ID for upstream dependencies of this TaskGroup.

        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the incoming edges to this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        Creates a unique ID for downstream dependencies of this TaskGroup.

        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Return a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            # Non-group children (tasks) are skipped; groups are recorded and descended into.
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize task group; required by DagNode."""
        from airflow.serialization.enums import DagAttributeTypes
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return (
            DagAttributeTypes.TASK_GROUP,
            TaskGroupSerialization.serialize_task_group(self),
        )

    def hierarchical_alphabetical_sort(self):
        """
        Sort children in hierarchical alphabetical order.

        - groups in alphabetical order first
        - tasks in alphabetical order after them.

        :return: list of tasks in hierarchical alphabetical order
        """
        return sorted(
            self.children.values(),
            key=lambda node: (not isinstance(node, TaskGroup), node.node_id),
        )

    def topological_sort(self):
        """
        Sorts children in topological order, such that a task comes after any of its upstream dependencies.

        :return: list of tasks in topological order
        :raises AirflowDagCycleException: if a full pass resolves no node, i.e. the
            remaining children depend on each other cyclically.
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if not self.children:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating
            # over a list copy of the values, we are allowed to modify the unsorted graph as we move
            # through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.parent_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    # No unresolved upstream was found for this node (loop ended without break).
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.parent_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """
        Return an iterator of the child tasks.

        Groups are visited breadth-first: the tasks directly in this group are yielded
        before those of nested groups.

        :raises ValueError: if a child is neither an AbstractOperator nor a TaskGroup.
        """
        from collections import deque

        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator

        # deque gives O(1) removal from the front of the queue.
        groups_to_visit: deque[TaskGroup] = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an "
                        f"AbstractOperator: {type(child).__module__}.{type(child).__qualname__}"
                    )
604
605
@attrs.define(kw_only=True, repr=False)
class MappedTaskGroup(TaskGroup):
    """
    A mapped task group.

    This doesn't really do anything special, just holds some additional metadata
    for expansion later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    _expand_input: DictOfListsExpandInput | ListOfDictsExpandInput = attrs.field(alias="expand_input")

    def __iter__(self):
        """
        Iterate over the contained tasks, descending into nested TaskGroups.

        :raises ValueError: if a direct child has ``trigger_rule`` set to
            ``TriggerRule.ALWAYS``.
        """
        for child in self.children.values():
            if getattr(child, "trigger_rule", None) == TriggerRule.ALWAYS:
                raise ValueError(
                    "Task-generated mapping within a mapped task group is not "
                    "allowed with trigger rule 'always'"
                )
            yield from self._iter_child(child)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Set every operator referenced by the expand input upstream of this group, then exit."""
        try:
            for op, _ in self._expand_input.iter_references():
                self.set_upstream(op)
        finally:
            # Always run the parent's exit so the group is popped from the context
            # even when wiring the upstream dependencies fails.
            super().__exit__(exc_type, exc_val, exc_tb)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for op, _ in XComArg.iter_xcom_references(self._expand_input):
            yield op