1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""A collection of closely related tasks on the same Dag that should be grouped together visually."""
19
20from __future__ import annotations
21
22import copy
23import re
24import weakref
25from collections import deque
26from collections.abc import Generator, Iterator, Sequence
27from typing import TYPE_CHECKING, Any
28
29import attrs
30
31from airflow.sdk import TriggerRule
32from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key
33from airflow.sdk.exceptions import (
34 AirflowDagCycleException,
35 DuplicateTaskIdFound,
36 TaskAlreadyInTaskGroup,
37)
38
39if TYPE_CHECKING:
40 from airflow.sdk.api.datamodels._generated import DagAttributeTypes
41 from airflow.sdk.bases.operator import BaseOperator
42 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
43 from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput
44 from airflow.sdk.definitions._internal.mixins import DependencyMixin
45 from airflow.sdk.definitions.dag import DAG
46 from airflow.sdk.definitions.edges import EdgeModifier
47 from airflow.sdk.types import Operator
48
49
def _default_parent_group() -> TaskGroup | None:
    """Return the TaskGroup active in the current context; used as the attrs default for ``parent_group``."""
    # Imported inside the function, presumably to avoid an import cycle with the context-manager module.
    from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

    current_group = TaskGroupContext.get_current()
    return current_group
54
55
def _parent_used_group_ids(tg: TaskGroup) -> set:
    """
    Return the set of ids already claimed in *tg*'s group hierarchy.

    When *tg* has a parent, the parent's set object itself is returned (not a copy),
    so nested groups share and mutate one set. A group without a parent gets a
    brand-new empty set.
    """
    parent = tg.parent_group
    return parent.used_group_ids if parent else set()
60
61
# This could be achieved with `@dag.default` and make this a method, but for some unknown reason when we do
# that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely lose track that this is an Attrs class. So
# we've gone with this and moved on with our lives; mypy is too much of a dark beast to battle over this.
def _default_dag(instance: TaskGroup):
    """
    Compute the default ``dag`` for a TaskGroup being constructed.

    A group with a parent inherits the parent's Dag; otherwise the Dag from the
    current ``DagContext`` is used (presumably None when no Dag is active).
    """
    from airflow.sdk.definitions._internal.contextmanager import DagContext

    parent = instance.parent_group
    if parent is None:
        return DagContext.get_current()
    return parent.dag
71
72
73# Mypy does not like a lambda for some reason. An explicit annotated function makes it happy.
def _validate_group_id(instance, attribute, value: str) -> None:
    """
    Validate ``TaskGroup.group_id`` via ``validate_group_key``.

    ``instance`` and ``attribute`` are part of the attrs validator call signature and are unused here.
    """
    del instance, attribute  # required by the attrs validator protocol, not needed here
    validate_group_key(value)
76
77
@attrs.define(repr=False)
class TaskGroup(DAGNode):
    """
    A collection of tasks.

    When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
    all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the Dag. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The Dag that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the Dag level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    :param group_display_name: If set, this will be the display name for the TaskGroup node in the UI.
    """

    # Un-prefixed id as given by the user; the public ``group_id`` property adds the parent prefix.
    _group_id: str | None = attrs.field(
        validator=attrs.validators.optional(_validate_group_id),
        # This is the default behaviour for attrs, but by specifying this it makes IDEs happier
        alias="group_id",
    )
    group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    prefix_group_id: bool = attrs.field(default=True)
    parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
    dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
    # Deep-copied so mutations here never leak back into the caller's dict.
    default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
    tooltip: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    # Direct children keyed by their node_id (task_id or group_id).
    children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)

    upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    downstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    # Taken from the parent group when there is one (the same set object), so all groups in a
    # hierarchy share it. The attribute cannot be rebound after init, but the set is mutated.
    used_group_ids: set[str] = attrs.field(
        default=attrs.Factory(_parent_used_group_ids, takes_self=True),
        init=False,
        on_setattr=attrs.setters.frozen,
    )

    ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str))
    ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str))

    add_suffix_on_collision: bool = False

    @dag.validator
    def _validate_dag(self, _attr, dag):
        """Reject construction without a Dag (explicit or from context)."""
        if not dag:
            raise RuntimeError("TaskGroup can only be used inside a dag")

    def __attrs_post_init__(self):
        # TODO: If attrs supported init only args we could use that here
        # https://github.com/python-attrs/attrs/issues/342
        self._check_for_group_id_collisions(self.add_suffix_on_collision)

        if self._group_id and not self.parent_group and self.dag:
            # Support `tg = TaskGroup(x, dag=dag)`
            self.parent_group = self.dag.task_group

        if self.parent_group:
            self.parent_group.add(self)
            if self.parent_group.default_args:
                # Parent defaults first so this group's own default_args win.
                self.default_args = {
                    **self.parent_group.default_args,
                    **self.default_args,
                }

        if self._group_id:
            # Reserve this group's id and its join-node ids so later nodes cannot reuse them.
            self.used_group_ids.add(self.group_id)
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Detect a group_id that is already used and either raise or rename this group.

        On collision with ``add_suffix_on_collision`` set, ``_group_id`` becomes
        ``<base>__<N>`` where N is one more than the largest suffix already used for the
        same base within the same prefix (e.g. task_group -> task_group__1 -> task_group__2).

        :raises DuplicateTaskIdFound: on collision when ``add_suffix_on_collision`` is False.
        """
        if self._group_id is None:
            return
        group_id = self.group_id
        if group_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        base = re.split(r"__\d+$", self._group_id)[0]
        # used_group_ids holds fully-prefixed ids ("parent.child__1"), while _group_id is the bare
        # label. group_id always ends with _group_id, so whatever precedes it is the parent prefix;
        # match against prefix + base so nested groups find their existing suffixes too.
        prefix = group_id[: len(group_id) - len(self._group_id)]
        pattern = re.compile(rf"^{re.escape(prefix + base)}__(\d+)$")
        suffixes = [
            int(match.group(1))
            for used_group_id in self.used_group_ids
            if used_group_id is not None and (match := pattern.match(used_group_id))
        ]
        self._group_id = f"{base}__{max(suffixes, default=0) + 1}"

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag, parent_group=None)

    @property
    def node_id(self):
        """Node id of this group in the Dag graph; same as ``group_id``."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self._group_id

    @property
    def task_group(self) -> TaskGroup | None:
        """Alias of ``parent_group`` so groups and tasks expose the same attribute."""
        return self.parent_group

    @task_group.setter
    def task_group(self, value: TaskGroup | None):
        self.parent_group = value

    def __iter__(self):
        """Yield every task in this group, descending into nested groups in insertion order."""
        for child in self.children.values():
            yield from self._iter_child(child)

    @staticmethod
    def _iter_child(child):
        """Yield *child* if it is a task, or all tasks inside it if it is a TaskGroup."""
        if isinstance(child, TaskGroup):
            yield from child
        else:
            yield child

    def add(self, task: DAGNode) -> DAGNode:
        """
        Add a task or TaskGroup to this TaskGroup.

        :raises TaskAlreadyInTaskGroup: if an operator already belongs to another group.
        :raises DuplicateTaskIdFound: if a child with the same node_id already exists.
        :raises ValueError: if a TaskGroup from a different Dag, or a non-empty TaskGroup, is added.

        :meta private:
        """
        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        if TaskGroupContext.active:
            # Inside a `with TaskGroup(...)` block, move the node out of its previous group.
            if task.task_group and task.task_group != self:
                task.task_group.children.pop(task.node_id, None)
                task.task_group = self
        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the Dag")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    raise ValueError(
                        f"Cannot mix TaskGroups from different Dags: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise ValueError("Cannot add a non-empty TaskGroup")

        self.children[key] = task
        return task

    def _remove(self, task: DAGNode) -> None:
        """
        Remove a direct child and release its id from ``used_group_ids``.

        :raises KeyError: if *task* is not a direct child, or its id is not in ``used_group_ids``.
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if (
            self._group_id
            and self.parent_group
            and self.parent_group.prefix_group_id
            and self.parent_group._group_id
        ):
            # defer to parent whether it adds a prefix
            return self.parent_group.child_id(self._group_id)
        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self.group_display_name or self._group_id

    def update_relative(
        self,
        other: DependencyMixin,
        upstream: bool = True,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Override TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise RuntimeError(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        # Helper function to find leaves from a task list or task group. For an empty group,
        # fall back to its upstream tasks, then walk up the parent chain.
        def find_leaves(group_or_task) -> list[Any]:
            while group_or_task:
                group_or_task_leaves = list(group_or_task.get_leaves())
                if group_or_task_leaves:
                    return group_or_task_leaves
                if group_or_task.upstream_task_ids:
                    upstream_task_ids_list = list(group_or_task.upstream_task_ids)
                    return [self.dag.get_task(task_id) for task_id in upstream_task_ids_list]
                group_or_task = group_or_task.parent_group
            return []

        # Check if the current TaskGroup is empty
        leaves = find_leaves(self)

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in leaves:  # Use the fetched leaves
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Make this group the current TaskGroup for nodes created inside the ``with`` block."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.push(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Pop this group from the TaskGroup context stack."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.pop()

    def has_task(self, task: BaseOperator) -> bool:
        """Return True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no upstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}
        for task in tasks:
            if task.upstream_task_ids.isdisjoint(ids):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Return a generator of tasks with no downstream dependencies within the TaskGroup.

        A teardown task with no in-group downstream is not yielded itself; instead the
        search walks upstream past teardowns to the nearest non-teardown tasks in the group.
        """
        tasks = list(self)
        ids = {x.task_id for x in tasks}

        def has_non_teardown_downstream(task, exclude: str):
            for down_task in task.downstream_list:
                if down_task.task_id == exclude:
                    continue
                if down_task.task_id not in ids:
                    continue
                if not down_task.is_teardown:
                    return True
            return False

        def recurse_for_first_non_teardown(task):
            for upstream_task in task.upstream_list:
                if upstream_task.task_id not in ids:
                    # upstream task is not in task group
                    continue
                elif upstream_task.is_teardown:
                    yield from recurse_for_first_non_teardown(upstream_task)
                elif task.is_teardown and upstream_task.is_setup:
                    # don't go through the teardown-to-setup path
                    continue
                # return unless upstream task already has non-teardown downstream in group
                elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
                    yield upstream_task

        for task in tasks:
            if task.downstream_task_ids.isdisjoint(ids):
                if not task.is_teardown:
                    yield task
                else:
                    yield from recurse_for_first_non_teardown(task)

    def child_id(self, label):
        """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        Creates a unique ID for upstream dependencies of this TaskGroup.

        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        Creates a unique ID for downstream dependencies of this TaskGroup.

        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Return a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize task group; required by DagNode."""
        from airflow.sdk.api.datamodels._generated import DagAttributeTypes
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return (
            DagAttributeTypes.TASK_GROUP,
            TaskGroupSerialization.serialize_task_group(self),
        )

    def hierarchical_alphabetical_sort(self):
        """
        Sort children in hierarchical alphabetical order.

        - groups in alphabetical order first
        - tasks in alphabetical order after them.

        :return: list of tasks in hierarchical alphabetical order
        """
        return sorted(
            self.children.values(),
            key=lambda node: (not isinstance(node, TaskGroup), node.node_id),
        )

    def topological_sort(self):
        """
        Sorts children in topological order, such that a task comes after any of its upstream dependencies.

        :return: list of tasks in topological order
        :raises AirflowDagCycleException: if the children cannot be ordered because of a cycle.
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if not self.children:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Iterating over a list() copy of
            # the values lets us modify the unsorted graph as we move through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.parent_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.parent_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """
        Return an iterator of the child tasks, visiting nested groups breadth-first.

        :raises ValueError: if a child is neither a TaskGroup nor an AbstractOperator.
        """
        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator

        groups_to_visit = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an "
                        f"AbstractOperator: {type(child).__module__}.{type(child).__qualname__}"
                    )
605
606
@attrs.define(kw_only=True, repr=False)
class MappedTaskGroup(TaskGroup):
    """
    A mapped task group.

    This doesn't really do anything special, just holds some additional metadata
    for expansion later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    # Expansion arguments for this group; passed to the constructor as ``expand_input=``.
    _expand_input: DictOfListsExpandInput | ListOfDictsExpandInput = attrs.field(alias="expand_input")

    def __iter__(self):
        """
        Yield every task in this group, descending into nested groups.

        :raises ValueError: if a direct child has trigger rule ``always``.
        """
        for child in self.children.values():
            if getattr(child, "trigger_rule", None) == TriggerRule.ALWAYS:
                raise ValueError(
                    "Task-generated mapping within a mapped task group is not "
                    "allowed with trigger rule 'always'"
                )
            yield from self._iter_child(child)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Wire upstream dependencies on every operator referenced by the expand input, then leave the context.

        The context pop (done by ``TaskGroup.__exit__``) runs in ``finally``, so that a
        failure while setting upstreams cannot leave this group stuck as the current TaskGroup.
        """
        try:
            for op, _ in self._expand_input.iter_references():
                self.set_upstream(op)
        finally:
            super().__exit__(exc_type, exc_val, exc_tb)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for op, _ in XComArg.iter_xcom_references(self._expand_input):
            yield op