Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/task_group.py: 38%
305 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""
19A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
20together when the DAG is displayed graphically.
21"""
22from __future__ import annotations
24import copy
25import functools
26import operator
27import re
28import weakref
29from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence
31from airflow.compat.functools import cache
32from airflow.exceptions import (
33 AirflowDagCycleException,
34 AirflowException,
35 DuplicateTaskIdFound,
36 TaskAlreadyInTaskGroup,
37)
38from airflow.models.taskmixin import DAGNode, DependencyMixin
39from airflow.serialization.enums import DagAttributeTypes
40from airflow.utils.helpers import validate_group_key
42if TYPE_CHECKING:
43 from sqlalchemy.orm import Session
45 from airflow.models.abstractoperator import AbstractOperator
46 from airflow.models.baseoperator import BaseOperator
47 from airflow.models.dag import DAG
48 from airflow.models.expandinput import ExpandInput
49 from airflow.models.operator import Operator
50 from airflow.utils.edgemodifier import EdgeModifier
class TaskGroup(DAGNode):
    """
    A collection of tasks. When set_downstream() or set_upstream() are called on the
    TaskGroup, it is applied across all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The DAG that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the DAG level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    """

    # Shared (same set object) by every TaskGroup created under one root group.
    used_group_ids: set[str | None]

    def __init__(
        self,
        group_id: str | None,
        prefix_group_id: bool = True,
        parent_group: TaskGroup | None = None,
        dag: DAG | None = None,
        default_args: dict[str, Any] | None = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
        add_suffix_on_collision: bool = False,
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id
        # Deep-copied so later mutation of the caller's dict does not leak into the group.
        self.default_args = copy.deepcopy(default_args or {})

        dag = dag or DagContext.get_current_dag()

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException("Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids = set()
            self.dag = dag
        else:
            if prefix_group_id:
                # If group id is used as prefix, it should not contain spaces nor dots
                # because it is used as prefix in the task_id
                validate_group_key(group_id)
            else:
                if not isinstance(group_id, str):
                    raise ValueError("group_id must be str")
                if not group_id:
                    raise ValueError("group_id must not be empty")

            if not parent_group and not dag:
                raise AirflowException("TaskGroup can only be used inside a dag")

            parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
            if not parent_group:
                raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
            if dag is not parent_group.dag:
                # Format the message eagerly: exceptions do not apply %-style arguments.
                raise RuntimeError(
                    f"Cannot mix TaskGroups from different DAGs: {dag} and {parent_group.dag}"
                )

            self.used_group_ids = parent_group.used_group_ids

        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        self._group_id = group_id
        self._check_for_group_id_collisions(add_suffix_on_collision)

        self.children: dict[str, DAGNode] = {}

        if parent_group:
            # add() attaches this group to the parent (and propagates the parent's dag),
            # which is what makes the prefixed ``group_id`` below resolvable.
            parent_group.add(self)
            self._update_default_args(parent_group)

        self.used_group_ids.add(self.group_id)
        if self.group_id:
            # Reserve the ids of the proxy join nodes as well so nothing else can claim them.
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: set[str | None] = set()
        self.downstream_group_ids: set[str | None] = set()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Resolve a clash between ``self._group_id`` and the ids already in ``used_group_ids``.

        Does nothing for the root group (``_group_id`` is None) or when the id is unused.

        :param add_suffix_on_collision: if True, rename this group to ``<base>__<n>`` where
            ``n`` is one more than the largest suffix already in use; if False, raise.
        :raises DuplicateTaskIdFound: if the id is taken and suffixing is disabled.
        """
        if self._group_id is None:
            return
        if self._group_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        base = re.split(r"__\d+$", self._group_id)[0]
        # Escape the base: with prefix_group_id=False the id is not validated and may
        # contain regex metacharacters.
        suffix_pattern = re.compile(rf"^{re.escape(base)}__(\d+)$")
        suffixes = []
        for used_group_id in self.used_group_ids:
            if used_group_id is None:
                continue
            matched = suffix_pattern.match(used_group_id)
            if matched:
                suffixes.append(int(matched.group(1)))

        if suffixes:
            self._group_id = f"{base}__{max(suffixes) + 1}"
        else:
            self._group_id += "__1"

    def _update_default_args(self, parent_group: TaskGroup):
        """
        Merge the parent group's ``default_args`` into this group's.

        Values set on this (more specific) group take precedence; the parent only
        supplies keys this group does not define itself.
        """
        if parent_group.default_args:
            self.default_args = {**parent_group.default_args, **self.default_args}

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag)

    @property
    def node_id(self):
        """Identifier of this node; for a TaskGroup this is its ``group_id``."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False."""
        return not self.group_id

    @property
    def parent_group(self) -> TaskGroup | None:
        """The TaskGroup containing this one (alias of ``task_group``)."""
        return self.task_group

    def __iter__(self):
        """Yield every non-TaskGroup child, descending recursively into nested TaskGroups."""
        for child in self.children.values():
            if isinstance(child, TaskGroup):
                yield from child
            else:
                yield child

    def add(self, task: DAGNode) -> None:
        """Add a task to this TaskGroup.

        :raises TaskAlreadyInTaskGroup: if an operator already belongs to a different group.
        :raises DuplicateTaskIdFound: if a child with the same node id already exists.
        :raises RuntimeError: if a TaskGroup attached to a different DAG is added.
        :raises AirflowException: if the TaskGroup being added already has children.

        :meta private:
        """
        from airflow.models.abstractoperator import AbstractOperator

        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # Format the message eagerly: exceptions do not apply %-style arguments.
                    raise RuntimeError(
                        f"Cannot mix TaskGroups from different DAGs: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise AirflowException("Cannot add a non-empty TaskGroup")

        self.children[key] = task

    def _remove(self, task: DAGNode) -> None:
        """
        Detach a direct child from this group and release its id from ``used_group_ids``.

        :raises KeyError: if the node is not a direct child of this group.
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
            # defer to parent whether it adds a prefix
            return self.task_group.child_id(self._group_id)

        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self._group_id

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Overrides TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.

        :param other: the TaskGroup or task-like object to relate to this group.
        :param upstream: True if ``other`` is upstream of this group, False if downstream.
        :param edge_modifier: if given, edge info is recorded on it for the join-node edges.
        :raises AirflowException: if one of ``other.roots`` is not a DAGNode.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                # ``parent`` is the group that receives the upstream id (here: self).
                parent, child = (self, other)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
            else:
                parent, child = (other, self)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise AirflowException(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                # Do not set a relationship between a TaskGroup and a Label's roots
                if self == task:
                    continue

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                else:
                    self.downstream_task_ids.add(task.node_id)
                    if edge_modifier:
                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in self.get_leaves():
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Make this group the current context-managed TaskGroup."""
        TaskGroupContext.push_context_managed_task_group(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Restore the previously current context-managed TaskGroup."""
        TaskGroupContext.pop_context_managed_task_group()

    def has_task(self, task: BaseOperator) -> bool:
        """Returns True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by TaskMixin."""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """
        Returns a generator of tasks that are root tasks, i.e. those with no upstream
        dependencies within the TaskGroup.
        """
        for task in self:
            if not any(self.has_task(parent) for parent in task.get_direct_relatives(upstream=True)):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Returns a generator of tasks that are leaf tasks, i.e. those with no downstream
        dependencies within the TaskGroup.
        """
        for task in self:
            if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)):
                yield task

    def child_id(self, label):
        """
        Prefix label with group_id if prefix_group_id is True. Otherwise return the label
        as-is.
        """
        if self.prefix_group_id:
            group_id = self.group_id
            if group_id:
                return f"{group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        If this TaskGroup has immediate upstream TaskGroups or tasks, a proxy node called
        upstream_join_id will be created in Graph view to join the edges coming into this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        If this TaskGroup has immediate downstream TaskGroups or tasks, a proxy node called
        downstream_join_id will be created in Graph view to join the outgoing edges from this
        TaskGroup to reduce the total number of edges needed to be displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Returns a flat dictionary of group_id: TaskGroup."""
        task_group_map = {}

        def build_map(task_group):
            # Only TaskGroups are recorded; other children end the recursion.
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Required by DAGNode."""
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)

    def topological_sort(self, _include_subdag_tasks: bool = False):
        """
        Sorts children in topological order, such that a task comes after any of its
        upstream dependencies.

        :param _include_subdag_tasks: if True, the sorted tasks of each SubDagOperator's
            subdag are appended right after that operator.
        :raises AirflowDagCycleException: if a full pass resolves no node.
        :return: list of tasks in topological order
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        from airflow.operators.subdag import SubDagOperator  # Avoid circular import

        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if not self.children:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that we iterate
            # over a list snapshot of the values, allowing us to modify the unsorted graph as we move
            # through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.task_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    # No unresolved upstream was found for this node: it can be emitted now.
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)
                    if _include_subdag_tasks and isinstance(node, SubDagOperator):
                        graph_sorted.extend(
                            node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
                        )

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.task_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """Returns an iterator of the child tasks.

        Groups are visited breadth-first, starting with this one.

        :raises ValueError: if a child is neither a TaskGroup nor an AbstractOperator.
        """
        from collections import deque

        from airflow.models.abstractoperator import AbstractOperator

        # deque gives O(1) removal from the front of the FIFO queue.
        groups_to_visit = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
                    )
class MappedTaskGroup(TaskGroup):
    """A task group that carries an expansion input.

    Beyond the regular TaskGroup behaviour this class only stores the
    ``expand_input`` metadata used for expansion later and wires the
    referenced operators as upstreams.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._expand_input = expand_input
        # Every operator referenced by the expand input becomes an upstream of the group.
        for referenced_op, _unused in expand_input.iter_references():
            self.set_upstream(referenced_op)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.models.xcom_arg import XComArg

        yield from (
            referenced_op for referenced_op, _unused in XComArg.iter_xcom_references(self._expand_input)
        )

    # NOTE(review): ``cache`` on a method keys on ``self``; presumably this keeps
    # instances alive for the lifetime of the cache -- confirm this is acceptable.
    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of instances a task in this group should be mapped to, when a DAG run is created.

        Only literal mapped arguments are considered here.

        The counts of this group and of every enclosing mapped task group are
        multiplied together.

        :meta private:

        :raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
        :return: The total number of mapped instances each task should have.
        """
        per_group_counts = (
            mapped_group._expand_input.get_parse_time_mapped_ti_count()
            for mapped_group in self.iter_mapped_task_groups()
        )
        return functools.reduce(operator.mul, per_group_counts)

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Number of instances a task in this group should be mapped to at run time.

        Both literal and non-literal mapped arguments are considered, so the
        result is available only once all depended tasks have finished. The
        return value should be identical to ``parse_time_mapped_ti_count`` if
        all mapped arguments are literal.

        The lengths of this group and of every enclosing mapped task group are
        multiplied together.

        :meta private:

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :return: Total number of mapped TIs this task should have.
        """
        mapped_groups = self.iter_mapped_task_groups()
        per_group_lengths = (
            mapped_group._expand_input.get_total_map_length(run_id, session=session)
            for mapped_group in mapped_groups
        )
        return functools.reduce(operator.mul, per_group_lengths)
class TaskGroupContext:
    """Tracks the current TaskGroup while TaskGroups are used as context managers."""

    # The TaskGroup whose ``with`` block is innermost right now, if any.
    _context_managed_task_group: TaskGroup | None = None
    # Stack of the enclosing TaskGroups that were current before the innermost one.
    _previous_context_managed_task_groups: list[TaskGroup] = []

    @classmethod
    def push_context_managed_task_group(cls, task_group: TaskGroup):
        """Make *task_group* current, stashing the previously current group (if any)."""
        previous = cls._context_managed_task_group
        if previous:
            cls._previous_context_managed_task_groups.append(previous)
        cls._context_managed_task_group = task_group

    @classmethod
    def pop_context_managed_task_group(cls) -> TaskGroup | None:
        """Return the current TaskGroup and restore the most recently stashed one as current."""
        popped = cls._context_managed_task_group
        stash = cls._previous_context_managed_task_groups
        cls._context_managed_task_group = stash.pop() if stash else None
        return popped

    @classmethod
    def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
        """Get the current TaskGroup."""
        from airflow.models.dag import DagContext

        current = cls._context_managed_task_group
        if current:
            return current

        dag = dag or DagContext.get_current_dag()
        if dag:
            # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
            return dag.task_group

        return current
def task_group_to_dict(task_item_or_group):
    """
    Build the nested dict describing a task or a TaskGroup (with its children) that is
    used to construct the Graph.

    An operator becomes a single node dict; a TaskGroup becomes a cluster dict whose
    ``children`` are its children sorted by label, followed by the upstream and
    downstream join nodes when the group has relatives on that side.
    """
    from airflow.models.abstractoperator import AbstractOperator

    def _join_node(join_id, group):
        # Unlabelled circle used as a proxy node for edges entering/leaving the group.
        return {
            "id": join_id,
            "value": {
                "label": "",
                "labelStyle": f"fill:{group.ui_fgcolor};",
                "style": f"fill:{group.ui_color};",
                "shape": "circle",
            },
        }

    if isinstance(task_item_or_group, AbstractOperator):
        operator_node = task_item_or_group
        node_id = operator_node.task_id
        node_value = {
            "label": operator_node.label,
            "labelStyle": f"fill:{operator_node.ui_fgcolor};",
            "style": f"fill:{operator_node.ui_color};",
            "rx": 5,
            "ry": 5,
        }
        return {"id": node_id, "value": node_value}

    task_group = task_item_or_group
    is_mapped = isinstance(task_group, MappedTaskGroup)

    children = []
    for child in sorted(task_group.children.values(), key=lambda node: node.label):
        children.append(task_group_to_dict(child))

    has_upstream = task_group.upstream_group_ids or task_group.upstream_task_ids
    if has_upstream:
        children.append(_join_node(task_group.upstream_join_id, task_group))

    has_downstream = task_group.downstream_group_ids or task_group.downstream_task_ids
    if has_downstream:
        # This is the join node used to reduce the number of edges between two TaskGroup.
        children.append(_join_node(task_group.downstream_join_id, task_group))

    group_id = task_group.group_id
    group_value = {
        "label": task_group.label,
        "labelStyle": f"fill:{task_group.ui_fgcolor};",
        "style": f"fill:{task_group.ui_color}",
        "rx": 5,
        "ry": 5,
        "clusterLabelPos": "top",
        "tooltip": task_group.tooltip,
        "isMapped": is_mapped,
    }
    return {"id": group_id, "value": group_value, "children": children}