1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""A collection of closely related tasks on the same DAG that should be grouped together visually."""
19
20from __future__ import annotations
21
22import copy
23import functools
24import operator
25import weakref
26from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence
27
28import methodtools
29import re2
30
31from airflow.exceptions import (
32 AirflowDagCycleException,
33 AirflowException,
34 DuplicateTaskIdFound,
35 TaskAlreadyInTaskGroup,
36)
37from airflow.models.taskmixin import DAGNode
38from airflow.serialization.enums import DagAttributeTypes
39from airflow.utils.helpers import validate_group_key
40
41if TYPE_CHECKING:
42 from sqlalchemy.orm import Session
43
44 from airflow.models.abstractoperator import AbstractOperator
45 from airflow.models.baseoperator import BaseOperator
46 from airflow.models.dag import DAG
47 from airflow.models.expandinput import ExpandInput
48 from airflow.models.operator import Operator
49 from airflow.models.taskmixin import DependencyMixin
50 from airflow.utils.edgemodifier import EdgeModifier
51
52
class TaskGroup(DAGNode):
    """
    A collection of tasks.

    When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across
    all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The DAG that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the DAG level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    """

    # Every group_id and join-node id handed out so far. The root TaskGroup creates this set
    # and every non-root TaskGroup shares the very same set object with its parent.
    used_group_ids: set[str | None]

    def __init__(
        self,
        group_id: str | None,
        prefix_group_id: bool = True,
        parent_group: TaskGroup | None = None,
        dag: DAG | None = None,
        default_args: dict[str, Any] | None = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
        add_suffix_on_collision: bool = False,
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id
        # Deep-copy so that later mutation of the caller's dict cannot leak into this group.
        self.default_args = copy.deepcopy(default_args or {})

        dag = dag or DagContext.get_current_dag()

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException("Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids = set()
            self.dag = dag
        else:
            if prefix_group_id:
                # If group id is used as prefix, it should not contain spaces nor dots
                # because it is used as prefix in the task_id
                validate_group_key(group_id)
            else:
                if not isinstance(group_id, str):
                    raise ValueError("group_id must be str")
                if not group_id:
                    raise ValueError("group_id must not be empty")

            if not parent_group and not dag:
                raise AirflowException("TaskGroup can only be used inside a dag")

            parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
            if not parent_group:
                raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
            if dag is not parent_group.dag:
                # Interpolate explicitly: exceptions do not apply %-formatting to extra args.
                raise RuntimeError(
                    f"Cannot mix TaskGroups from different DAGs: {dag} and {parent_group.dag}"
                )

            self.used_group_ids = parent_group.used_group_ids

        # Validate (and possibly rename with a ``__<n>`` suffix) before the id is registered.
        self._group_id = group_id
        self._check_for_group_id_collisions(add_suffix_on_collision)

        self.children: dict[str, DAGNode] = {}

        if parent_group:
            parent_group.add(self)
            self._update_default_args(parent_group)

        self.used_group_ids.add(self.group_id)
        if self.group_id:
            # Reserve the ids of the proxy join nodes drawn in the Graph view as well.
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: set[str | None] = set()
        self.downstream_group_ids: set[str | None] = set()
        self.upstream_task_ids: set[str] = set()
        self.downstream_task_ids: set[str] = set()

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Make sure ``self._group_id`` is not already present in ``used_group_ids``.

        :param add_suffix_on_collision: when True, a colliding id is renamed to
            ``<base>__<n>`` where ``n`` is one more than the largest suffix already used for
            that base (or ``__1`` if none is used yet); when False a collision raises.
        :raises DuplicateTaskIdFound: on collision when ``add_suffix_on_collision`` is False.
        """
        if self._group_id is None:
            return
        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        if self._group_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        base = re2.split(r"__\d+$", self._group_id)[0]
        # Match "<base>__<digits>" with plain string operations instead of interpolating
        # ``base`` into a regex: with prefix_group_id=False the id is never validated, so it
        # may contain regex metacharacters ("." / "(" / "+" ...) that would mis-match or fail
        # to compile. ASCII-only digits mirror re2's ``\d``.
        prefix = f"{base}__"
        suffixes = []
        for used_group_id in self.used_group_ids:
            if used_group_id is None or not used_group_id.startswith(prefix):
                continue
            suffix = used_group_id[len(prefix) :]
            if suffix.isascii() and suffix.isdigit():
                suffixes.append(int(suffix))

        if not suffixes:
            self._group_id += "__1"
        else:
            self._group_id = f"{base}__{max(suffixes) + 1}"

    def _update_default_args(self, parent_group: TaskGroup):
        """Layer the parent's default_args underneath this group's own; own keys take precedence."""
        if parent_group.default_args:
            self.default_args = {**parent_group.default_args, **self.default_args}

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag)

    @property
    def node_id(self):
        """Node id of this TaskGroup; identical to ``group_id``."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self.group_id

    @property
    def parent_group(self) -> TaskGroup | None:
        """The TaskGroup containing this one (the ``task_group`` attribute); None for the root."""
        return self.task_group

    def __iter__(self):
        """Yield every non-TaskGroup child, descending depth-first into nested TaskGroups."""
        for child in self.children.values():
            if isinstance(child, TaskGroup):
                yield from child
            else:
                yield child

    def add(self, task: DAGNode) -> DAGNode:
        """Add a task to this TaskGroup.

        :meta private:
        """
        from airflow.models.abstractoperator import AbstractOperator

        if TaskGroupContext.active:
            # Inside a ``with TaskGroup(...)`` block: move the node out of its previous group.
            if task.task_group and task.task_group != self:
                task.task_group.children.pop(task.node_id, None)
                task.task_group = self
        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # Interpolate explicitly: exceptions do not apply %-formatting to extra args.
                    raise RuntimeError(
                        f"Cannot mix TaskGroups from different DAGs: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise AirflowException("Cannot add a non-empty TaskGroup")

        self.children[key] = task
        return task

    def _remove(self, task: DAGNode) -> None:
        """
        Detach ``task`` from this group's children and release its id from ``used_group_ids``.

        :raises KeyError: if ``task`` is not a direct child of this group, or its id is not
            present in ``used_group_ids``.
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
            # defer to parent whether it adds a prefix
            return self.task_group.child_id(self._group_id)

        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self._group_id

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Override TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise AirflowException(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in self.get_leaves():
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Make this group the current context-managed TaskGroup."""
        TaskGroupContext.push_context_managed_task_group(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Restore the previously context-managed TaskGroup."""
        TaskGroupContext.pop_context_managed_task_group()

    def has_task(self, task: BaseOperator) -> bool:
        """Return True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """Return a generator of tasks with no upstream dependencies within the TaskGroup."""
        tasks = list(self)
        ids = {x.task_id for x in tasks}
        for task in tasks:
            if task.upstream_task_ids.isdisjoint(ids):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Return a generator of tasks with no downstream dependencies within the TaskGroup.

        A teardown task that would be a leaf is replaced by the first non-teardown task(s)
        upstream of it inside the group, so teardowns are not treated as the group's leaves.
        """
        tasks = list(self)
        ids = {x.task_id for x in tasks}

        def has_non_teardown_downstream(task, exclude: str):
            # True if ``task`` has an in-group, non-teardown downstream other than ``exclude``.
            for down_task in task.downstream_list:
                if down_task.task_id == exclude:
                    continue
                elif down_task.task_id not in ids:
                    continue
                elif not down_task.is_teardown:
                    return True
            return False

        def recurse_for_first_non_teardown(task):
            for upstream_task in task.upstream_list:
                if upstream_task.task_id not in ids:
                    # upstream task is not in task group
                    continue
                elif upstream_task.is_teardown:
                    yield from recurse_for_first_non_teardown(upstream_task)
                elif task.is_teardown and upstream_task.is_setup:
                    # don't go through the teardown-to-setup path
                    continue
                # return unless upstream task already has non-teardown downstream in group
                elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
                    yield upstream_task

        for task in tasks:
            if task.downstream_task_ids.isdisjoint(ids):
                if not task.is_teardown:
                    yield task
                else:
                    yield from recurse_for_first_non_teardown(task)

    def child_id(self, label):
        """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        Creates a unique ID for upstream dependencies of this TaskGroup.

        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the incoming edges into this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        Creates a unique ID for downstream dependencies of this TaskGroup.

        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Return a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize task group; required by DAGNode."""
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)

    def hierarchical_alphabetical_sort(self):
        """
        Sort children in hierarchical alphabetical order.

        - groups in alphabetical order first
        - tasks in alphabetical order after them.

        :return: list of direct children (groups and tasks) in hierarchical alphabetical order
        """
        return sorted(
            self.children.values(), key=lambda node: (not isinstance(node, TaskGroup), node.node_id)
        )

    def topological_sort(self, _include_subdag_tasks: bool = False):
        """
        Sorts children in topological order, such that a task comes after any of its upstream dependencies.

        :return: list of tasks in topological order
        :raises AirflowDagCycleException: if a pass over the remaining nodes resolves none of them.
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        from airflow.operators.subdag import SubDagOperator  # Avoid circular import

        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if not self.children:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating
            # over a list copy of values(), we are allowed to modify the unsorted graph as we move
            # through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.task_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)
                    if _include_subdag_tasks and isinstance(node, SubDagOperator):
                        graph_sorted.extend(
                            node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
                        )

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.task_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """
        Return an iterator of the child tasks.

        Nested TaskGroups are walked breadth-first, so the tasks directly in this group are
        yielded before the tasks of any sub-group.

        :raises ValueError: if a child is neither a TaskGroup nor an AbstractOperator.
        """
        from collections import deque

        from airflow.models.abstractoperator import AbstractOperator

        # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
        groups_to_visit = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
                    )
566
567
class MappedTaskGroup(TaskGroup):
    """A mapped task group.

    This doesn't really do anything special, just holds some additional metadata
    for expansion later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Read by the mapped-count helpers and by __exit__ to wire upstream dependencies.
        self._expand_input = expand_input

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.models.xcom_arg import XComArg

        for op, _ in XComArg.iter_xcom_references(self._expand_input):
            yield op

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the Number of instances a task in this group should be mapped to, when a DAG run is created.

        This only considers literal mapped arguments. If any non-literal value is used for
        mapping, the count cannot be known at parse time and an error is raised instead
        (see ``:raise:`` below).

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
        :return: The total number of mapped instances each task should have.
        """
        # iter_mapped_task_groups() always yields at least ``self``, so reduce() needs no initial.
        return functools.reduce(
            operator.mul,
            (g._expand_input.get_parse_time_mapped_ti_count() for g in self.iter_mapped_task_groups()),
        )

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """
        Return the number of instances a task in this group should be mapped to at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore available when all depended tasks have finished. The
        return value should be identical to ``parse_time_mapped_ti_count`` if
        all mapped arguments are literal.

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :return: Total number of mapped TIs this task should have.
        """
        groups = self.iter_mapped_task_groups()
        return functools.reduce(
            operator.mul,
            (g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Wire the expand-input references as upstreams, then leave the TaskGroup context."""
        try:
            for op, _ in self._expand_input.iter_references():
                self.set_upstream(op)
        finally:
            # Always pop the context, even if wiring dependencies failed; otherwise the
            # TaskGroupContext stack would keep this group as the "current" one.
            super().__exit__(exc_type, exc_val, exc_tb)
637
638
class TaskGroupContext:
    """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""

    # True from the first push until the stack of context-managed groups is fully popped.
    active: bool = False
    # Innermost TaskGroup currently entered with ``with``; None when nothing is entered.
    _context_managed_task_group: TaskGroup | None = None
    # Enclosing (outer) context-managed TaskGroups, restored last-in-first-out on pop.
    # This is a class attribute, so it is shared by every caller that goes through the class.
    _previous_context_managed_task_groups: list[TaskGroup] = []

    @classmethod
    def push_context_managed_task_group(cls, task_group: TaskGroup):
        """Make ``task_group`` current, saving the previously current group (if any) on the stack."""
        enclosing = cls._context_managed_task_group
        if enclosing:
            cls._previous_context_managed_task_groups.append(enclosing)
        cls._context_managed_task_group = task_group
        cls.active = True

    @classmethod
    def pop_context_managed_task_group(cls) -> TaskGroup | None:
        """
        Leave the current TaskGroup and restore the enclosing one.

        :return: the TaskGroup that was current before this call (None if there was none).
        """
        leaving = cls._context_managed_task_group
        saved = cls._previous_context_managed_task_groups
        if not saved:
            cls._context_managed_task_group = None
            cls.active = False
        else:
            cls._context_managed_task_group = saved.pop()
        return leaving

    @classmethod
    def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
        """
        Get the current TaskGroup.

        Without a context-managed group, fall back to the root TaskGroup of ``dag`` (or of the
        DAG from ``DagContext`` when ``dag`` is not given).
        """
        from airflow.models.dag import DagContext

        current = cls._context_managed_task_group
        if current:
            return current

        fallback_dag = dag or DagContext.get_current_dag()
        if fallback_dag:
            return fallback_dag.task_group
        return current
677
678
def task_group_to_dict(task_item_or_group):
    """
    Create a nested dict representation of this TaskGroup and its children used to construct the Graph.

    A task (AbstractOperator) becomes a leaf node. A TaskGroup becomes a cluster node whose
    ``children`` are its members sorted by label, followed by an upstream and/or downstream
    join node when the group has upstream/downstream group or task dependencies.
    """
    from airflow.models.abstractoperator import AbstractOperator
    from airflow.models.mappedoperator import MappedOperator

    node = task_item_or_group
    if isinstance(node, AbstractOperator):
        task_value = {
            "label": node.label,
            "labelStyle": f"fill:{node.ui_fgcolor};",
            "style": f"fill:{node.ui_color};",
            "rx": 5,
            "ry": 5,
        }
        # Optional keys are inserted in the same order as before: isMapped, then setupTeardownType.
        if isinstance(node, MappedOperator):
            task_value["isMapped"] = True
        if node.is_setup is True:
            task_value["setupTeardownType"] = "setup"
        elif node.is_teardown is True:
            task_value["setupTeardownType"] = "teardown"
        return {"id": node.task_id, "value": task_value}

    group = node

    def join_node(node_id):
        # A fresh dict per call, so the two join nodes never share a mutable "value".
        return {
            "id": node_id,
            "value": {
                "label": "",
                "labelStyle": f"fill:{group.ui_fgcolor};",
                "style": f"fill:{group.ui_color};",
                "shape": "circle",
            },
        }

    ordered_members = sorted(group.children.values(), key=operator.attrgetter("label"))
    children = [task_group_to_dict(member) for member in ordered_members]

    if group.upstream_group_ids or group.upstream_task_ids:
        children.append(join_node(group.upstream_join_id))

    if group.downstream_group_ids or group.downstream_task_ids:
        # This is the join node used to reduce the number of edges between two TaskGroup.
        children.append(join_node(group.downstream_join_id))

    group_value = {
        "label": group.label,
        "labelStyle": f"fill:{group.ui_fgcolor};",
        "style": f"fill:{group.ui_color}",
        "rx": 5,
        "ry": 5,
        "clusterLabelPos": "top",
        "tooltip": group.tooltip,
        "isMapped": isinstance(group, MappedTaskGroup),
    }
    return {"id": group.group_id, "value": group_value, "children": children}