1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""A collection of closely related tasks on the same Dag that should be grouped together visually."""
19
20from __future__ import annotations
21
22import copy
23import functools
24import operator
25import re
26import weakref
27from collections.abc import Generator, Iterator, Sequence
28from typing import TYPE_CHECKING, Any
29
30import attrs
31import methodtools
32
33from airflow.sdk import TriggerRule
34from airflow.sdk.definitions._internal.node import DAGNode, validate_group_key
35from airflow.sdk.exceptions import (
36 AirflowDagCycleException,
37 DuplicateTaskIdFound,
38 TaskAlreadyInTaskGroup,
39)
40
41if TYPE_CHECKING:
42 from airflow.sdk.bases.operator import BaseOperator
43 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
44 from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput
45 from airflow.sdk.definitions._internal.mixins import DependencyMixin
46 from airflow.sdk.definitions.dag import DAG
47 from airflow.sdk.definitions.edges import EdgeModifier
48 from airflow.sdk.types import Operator
49 from airflow.serialization.enums import DagAttributeTypes
50
51
def _default_parent_group() -> TaskGroup | None:
    """
    Attrs default factory for ``TaskGroup.parent_group``.

    :return: whatever ``TaskGroupContext.get_current()`` reports as the current TaskGroup.
    """
    # Imported lazily, as in the rest of this module, to avoid an import cycle at module load time.
    from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext as _TaskGroupContext

    current_group = _TaskGroupContext.get_current()
    return current_group
56
57
58def _parent_used_group_ids(tg: TaskGroup) -> set:
59 if tg.parent_group:
60 return tg.parent_group.used_group_ids
61 return set()
62
63
# This could be achieved with `@dag.default` by making this a method, but for some unknown reason when we do
# that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely lose track that this is an Attrs class. So
# we've gone with this and moved on with our lives; mypy is too much of a dark beast to battle over this.
def _default_dag(instance: TaskGroup):
    """
    Attrs default factory for ``TaskGroup.dag``.

    A group with a parent inherits the parent's Dag; otherwise the Dag reported by
    ``DagContext.get_current()`` is used.
    """
    from airflow.sdk.definitions._internal.contextmanager import DagContext

    parent = instance.parent_group
    if parent is None:
        return DagContext.get_current()
    return parent.dag
73
74
75# Mypy does not like a lambda for some reason. An explicit annotated function makes it happy.
def _validate_group_id(instance, attribute, value: str) -> None:
    """
    Validate a ``group_id`` passed to ``TaskGroup``.

    Used as an attrs validator (wrapped in ``attrs.validators.optional``, so ``None`` never
    reaches it); *instance* and *attribute* are part of the validator signature but unused.
    The actual checks live in ``validate_group_key`` (presumably raising on an invalid key).
    """
    _ = (instance, attribute)  # required by the attrs validator protocol, not needed here
    validate_group_key(value)
78
79
@attrs.define(repr=False)
class TaskGroup(DAGNode):
    """
    A collection of tasks.

    When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
    all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the Dag. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The Dag that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the Dag level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1`, `__2`, ... suffixes (one more than the largest suffix in use)
    :param group_display_name: If set, this will be the display name for the TaskGroup node in the UI.
    """

    # The id exactly as given by the user; the ``group_id`` property adds the parent's prefix.
    _group_id: str | None = attrs.field(
        validator=attrs.validators.optional(_validate_group_id),
        # This is the default behaviour for attrs, but by specifying this it makes IDEs happier
        alias="group_id",
    )
    group_display_name: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    prefix_group_id: bool = attrs.field(default=True)
    parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
    dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
    # Deep-copied on the way in so later merging/mutation never touches the caller's dict.
    default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
    tooltip: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
    # Direct children (tasks and TaskGroups), keyed by their node_id.
    children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)

    # Relationship bookkeeping filled by update_relative, used to draw fewer edges in Graph view.
    upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    downstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    # The very same set object as the parent's (see _parent_used_group_ids), so every group under
    # one parent shares one registry of ids in use. Frozen so the shared set is never rebound.
    used_group_ids: set[str] = attrs.field(
        default=attrs.Factory(_parent_used_group_ids, takes_self=True),
        init=False,
        on_setattr=attrs.setters.frozen,
    )

    ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str))
    ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str))

    add_suffix_on_collision: bool = False

    @dag.validator
    def _validate_dag(self, _attr, dag):
        """Reject construction without a Dag (neither passed, inherited, nor from the context)."""
        if not dag:
            raise RuntimeError("TaskGroup can only be used inside a dag")

    def __attrs_post_init__(self):
        """Resolve id collisions, attach to the parent group, and register the ids now in use."""
        # TODO: If attrs supported init only args we could use that here
        # https://github.com/python-attrs/attrs/issues/342
        self._check_for_group_id_collisions(self.add_suffix_on_collision)

        if self._group_id and not self.parent_group and self.dag:
            # Support `tg = TaskGroup(x, dag=dag)`
            self.parent_group = self.dag.task_group

        if self.parent_group:
            self.parent_group.add(self)
            if self.parent_group.default_args:
                # Parent defaults first, so this group's own default_args win on conflicts.
                self.default_args = {
                    **self.parent_group.default_args,
                    **self.default_args,
                }

        if self._group_id:
            self.used_group_ids.add(self.group_id)
        self.used_group_ids.add(self.downstream_join_id)
        self.used_group_ids.add(self.upstream_join_id)

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Make sure this group's fully-qualified id is not already in use.

        If it is, either raise or (with *add_suffix_on_collision*) rename the group by appending
        ``__<n>``, where ``n`` is one more than the largest suffix already used for the same base id.
        Example: task_group ==> task_group__1 -> task_group__2 -> task_group__3

        :raises DuplicateTaskIdFound: if the id is taken and *add_suffix_on_collision* is False.
        """
        if self._group_id is None:
            return
        full_group_id = self.group_id
        if full_group_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        base = re.split(r"__\d+$", self._group_id)[0]
        # ``used_group_ids`` stores fully-qualified ids (e.g. ``parent.tg__1``) while ``_group_id``
        # is unprefixed, so the search must include the parent prefix. Matching on the bare base
        # would never find the suffixes of nested groups and would hand out ``__1`` repeatedly.
        prefix = full_group_id[: len(full_group_id) - len(self._group_id)]
        suffix_pattern = re.compile(rf"^{re.escape(prefix + base)}__(\d+)$")
        suffixes = [
            int(match.group(1))
            for used_group_id in self.used_group_ids
            if used_group_id is not None and (match := suffix_pattern.match(used_group_id))
        ]
        if suffixes:
            self._group_id = f"{base}__{max(suffixes) + 1}"
        else:
            self._group_id += "__1"

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag, parent_group=None)

    @property
    def node_id(self):
        """Node id of this group, i.e. its (possibly prefixed) group_id."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self._group_id

    @property
    def task_group(self) -> TaskGroup | None:
        """Alias of ``parent_group`` so TaskGroups expose the same attribute as tasks."""
        return self.parent_group

    @task_group.setter
    def task_group(self, value: TaskGroup | None):
        self.parent_group = value

    def __iter__(self):
        """Iterate over all tasks in this group, descending into nested TaskGroups."""
        for child in self.children.values():
            yield from self._iter_child(child)

    @staticmethod
    def _iter_child(child):
        """Yield *child* itself if it is a task, or every task within it if it is a TaskGroup."""
        if isinstance(child, TaskGroup):
            yield from child
        else:
            yield child

    def add(self, task: DAGNode) -> DAGNode:
        """
        Add a task or TaskGroup to this TaskGroup.

        :param task: the node to register as a direct child of this group.
        :return: the same node.
        :raises TaskAlreadyInTaskGroup: if an operator already belongs to another TaskGroup.
        :raises DuplicateTaskIdFound: if a child with the same node_id was already added.
        :raises ValueError: if *task* is a TaskGroup of a different Dag, or a non-empty TaskGroup.

        :meta private:
        """
        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        if TaskGroupContext.active:
            # While a TaskGroup context is active (presumably inside a ``with TaskGroup(...)``
            # block), a node belonging to another group is moved here instead of rejected.
            if task.task_group and task.task_group != self:
                task.task_group.children.pop(task.node_id, None)
                task.task_group = self
        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        # A weak proxy keeps the child from holding a strong reference to this group.
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the Dag")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    raise ValueError(
                        f"Cannot mix TaskGroups from different Dags: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise ValueError("Cannot add a non-empty TaskGroup")

        self.children[key] = task
        return task

    def _remove(self, task: DAGNode) -> None:
        """
        Detach a direct child and release its id from ``used_group_ids``.

        :raises KeyError: if *task* is not a direct child, or its id is not in ``used_group_ids``.
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup, prefixed with the parent's group_id when the parent asks for it."""
        if (
            self._group_id
            and self.parent_group
            and self.parent_group.prefix_group_id
            and self.parent_group._group_id
        ):
            # defer to parent whether it adds a prefix
            return self.parent_group.child_id(self._group_id)
        return self._group_id

    @property
    def label(self) -> str | None:
        """Node label in the UI: ``group_display_name`` if set, else group_id without the parent prefix."""
        return self.group_display_name or self._group_id

    def update_relative(
        self,
        other: DependencyMixin,
        upstream: bool = True,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Override TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.

        :raises RuntimeError: if a root of *other* is not a DAGNode.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise RuntimeError(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        # Helper function to find leaves from a task list or task group
        def find_leaves(group_or_task) -> list[Any]:
            # Walk up the group hierarchy until something can act as the leaves: the group's own
            # leaf tasks, or, for a group without tasks, the tasks recorded as its upstream.
            while group_or_task:
                group_or_task_leaves = list(group_or_task.get_leaves())
                if group_or_task_leaves:
                    return group_or_task_leaves
                if group_or_task.upstream_task_ids:
                    upstream_task_ids_list = list(group_or_task.upstream_task_ids)
                    return [self.dag.get_task(task_id) for task_id in upstream_task_ids_list]
                group_or_task = group_or_task.parent_group
            return []

        # Resolve the leaves before update_relative below mutates this group's bookkeeping sets.
        leaves = find_leaves(self)

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in leaves:  # Use the fetched leaves
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Make this group the current TaskGroup for nodes created inside the ``with`` block."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.push(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Pop this group from the TaskGroup context stack."""
        from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext

        TaskGroupContext.pop()

    def has_task(self, task: BaseOperator) -> bool:
        """Return True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by DependencyMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no upstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}
        for task in tasks:
            if task.upstream_task_ids.isdisjoint(ids):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Return a generator of tasks with no downstream dependencies within the TaskGroup.

        Teardown tasks are not reported as leaves; instead, for a teardown with no in-group
        downstream, the nearest non-teardown upstream tasks in the group are yielded.
        """
        tasks = list(self)
        ids = {x.task_id for x in tasks}

        def has_non_teardown_downstream(task, exclude: str):
            # True if *task* has an in-group, non-teardown downstream other than *exclude*.
            for down_task in task.downstream_list:
                if down_task.task_id == exclude:
                    continue
                if down_task.task_id not in ids:
                    continue
                if not down_task.is_teardown:
                    return True
            return False

        def recurse_for_first_non_teardown(task):
            for upstream_task in task.upstream_list:
                if upstream_task.task_id not in ids:
                    # upstream task is not in task group
                    continue
                elif upstream_task.is_teardown:
                    yield from recurse_for_first_non_teardown(upstream_task)
                elif task.is_teardown and upstream_task.is_setup:
                    # don't go through the teardown-to-setup path
                    continue
                # return unless upstream task already has non-teardown downstream in group
                elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
                    yield upstream_task

        for task in tasks:
            if task.downstream_task_ids.isdisjoint(ids):
                if not task.is_teardown:
                    yield task
                else:
                    yield from recurse_for_first_non_teardown(task)

    def child_id(self, label):
        """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        Creates a unique ID for upstream dependencies of this TaskGroup.

        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the incoming edges into this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        Creates a unique ID for downstream dependencies of this TaskGroup.

        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Return a flat dictionary of group_id: TaskGroup, including this group itself."""
        task_group_map = {}

        def build_map(task_group):
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize task group; required by DagNode."""
        from airflow.serialization.enums import DagAttributeTypes
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return (
            DagAttributeTypes.TASK_GROUP,
            TaskGroupSerialization.serialize_task_group(self),
        )

    def hierarchical_alphabetical_sort(self):
        """
        Sort children in hierarchical alphabetical order.

        - groups in alphabetical order first
        - tasks in alphabetical order after them.

        :return: list of direct children (tasks and groups) in hierarchical alphabetical order
        """
        return sorted(
            self.children.values(),
            key=lambda node: (not isinstance(node, TaskGroup), node.node_id),
        )

    def topological_sort(self):
        """
        Sorts children in topological order, such that a task comes after any of its upstream dependencies.

        :return: list of direct children in topological order
        :raises AirflowDagCycleException: if the children cannot be ordered because of a cycle.
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if not self.children:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Iterating over a list() of
            # the values takes a snapshot, allowing us to modify the unsorted graph as we move through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.parent_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.parent_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """
        Return an iterator of the child tasks, visiting nested groups breadth-first.

        :raises ValueError: if a child is neither a TaskGroup nor an AbstractOperator.
        """
        from collections import deque

        from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator

        # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
        groups_to_visit = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an "
                        f"AbstractOperator: {type(child).__module__}.{type(child).__qualname__}"
                    )
607
608
@attrs.define(kw_only=True, repr=False)
class MappedTaskGroup(TaskGroup):
    """
    A mapped task group.

    This doesn't really do anything special, just holds some additional metadata
    for expansion later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    # The expansion arguments supplied through ``expand``/``expand_kwargs``.
    _expand_input: DictOfListsExpandInput | ListOfDictsExpandInput = attrs.field(alias="expand_input")

    def __iter__(self):
        """
        Iterate over all tasks in this group, descending into nested TaskGroups.

        :raises ValueError: if a direct child uses trigger rule ``always``, which is not allowed
            inside a mapped task group.
        """
        for child in self.children.values():
            if getattr(child, "trigger_rule", None) == TriggerRule.ALWAYS:
                raise ValueError(
                    "Task-generated mapping within a mapped task group is not "
                    "allowed with trigger rule 'always'"
                )
            yield from self._iter_child(child)

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of instances a task in this group should be mapped to when a Dag run is created.

        This only considers literal mapped arguments; it does not return a count when non-literal
        values are used for mapping, but lets ``NotFullyPopulated`` from the expand input propagate.

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
        :return: The total number of mapped instances each task should have.
        """
        # iter_mapped_task_groups always yields at least self, so reduce() never sees an empty input.
        return functools.reduce(
            operator.mul,
            (g._expand_input.get_parse_time_mapped_ti_count() for g in self.iter_mapped_task_groups()),
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Wire every operator referenced by the expand input upstream of this group, then leave the context."""
        try:
            for op, _ in self._expand_input.iter_references():
                self.set_upstream(op)
        finally:
            # Always pop the TaskGroup context, even if wiring a dependency failed; otherwise the
            # context stack would keep this group as "current" for unrelated code.
            super().__exit__(exc_type, exc_val, exc_tb)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for op, _ in XComArg.iter_xcom_references(self._expand_input):
            yield op