Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/task_group.py: 39%
289 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""
19A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
20together when the DAG is displayed graphically.
21"""
22from __future__ import annotations
24import copy
25import functools
26import operator
27import re
28import weakref
29from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence
31from airflow.compat.functools import cache
32from airflow.exceptions import (
33 AirflowDagCycleException,
34 AirflowException,
35 DuplicateTaskIdFound,
36 TaskAlreadyInTaskGroup,
37)
38from airflow.models.taskmixin import DAGNode, DependencyMixin
39from airflow.serialization.enums import DagAttributeTypes
40from airflow.utils.helpers import validate_group_key
42if TYPE_CHECKING:
43 from sqlalchemy.orm import Session
45 from airflow.models.abstractoperator import AbstractOperator
46 from airflow.models.baseoperator import BaseOperator
47 from airflow.models.dag import DAG
48 from airflow.models.expandinput import ExpandInput
49 from airflow.models.operator import Operator
50 from airflow.utils.edgemodifier import EdgeModifier
class TaskGroup(DAGNode):
    """
    A collection of tasks. When set_downstream() or set_upstream() are called on the
    TaskGroup, it is applied across all tasks within the group if necessary.

    :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict
        with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id
        set to None.
    :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with
        this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed.
        Default is True.
    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None
        for the root TaskGroup.
    :param dag: The DAG that this TaskGroup belongs to.
    :param default_args: A dictionary of default parameters to be used
        as constructor keyword parameters when initialising operators,
        will override default_args defined in the DAG level.
        Note that operators have the same hook, and precede those defined
        here, meaning that if your dict contains `'depends_on_past': True`
        here and `'depends_on_past': False` in the operator's call
        `default_args`, the actual value will be `False`.
    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
    :param ui_color: The fill color of the TaskGroup node when displayed in the UI
    :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
    :param add_suffix_on_collision: If this task group name already exists,
        automatically add `__1` etc suffixes
    """

    # Shared (same set object) by every TaskGroup created under the same root group.
    used_group_ids: set[str | None]

    def __init__(
        self,
        group_id: str | None,
        prefix_group_id: bool = True,
        parent_group: TaskGroup | None = None,
        dag: DAG | None = None,
        default_args: dict[str, Any] | None = None,
        tooltip: str = "",
        ui_color: str = "CornflowerBlue",
        ui_fgcolor: str = "#000",
        add_suffix_on_collision: bool = False,
    ):
        from airflow.models.dag import DagContext

        self.prefix_group_id = prefix_group_id
        # Deep-copied so later mutation of the caller's dict does not leak into this group.
        self.default_args = copy.deepcopy(default_args or {})

        dag = dag or DagContext.get_current_dag()

        if group_id is None:
            # This creates a root TaskGroup.
            if parent_group:
                raise AirflowException("Root TaskGroup cannot have parent_group")
            # used_group_ids is shared across all TaskGroups in the same DAG to keep track
            # of used group_id to avoid duplication.
            self.used_group_ids = set()
            self.dag = dag
        else:
            if prefix_group_id:
                # If group id is used as prefix, it should not contain spaces nor dots
                # because it is used as prefix in the task_id
                validate_group_key(group_id)
            else:
                if not isinstance(group_id, str):
                    raise ValueError("group_id must be str")
                if not group_id:
                    raise ValueError("group_id must not be empty")

            if not parent_group and not dag:
                raise AirflowException("TaskGroup can only be used inside a dag")

            parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
            if not parent_group:
                raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
            if dag is not parent_group.dag:
                # Interpolate the message; passing "%s" placeholders plus args to the
                # exception constructor would leave them unformatted.
                raise RuntimeError(
                    f"Cannot mix TaskGroups from different DAGs: {dag} and {parent_group.dag}"
                )

            self.used_group_ids = parent_group.used_group_ids

        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        self._group_id = group_id
        self._check_for_group_id_collisions(add_suffix_on_collision)

        self.children: dict[str, DAGNode] = {}
        if parent_group:
            # Registering with the parent also sets self.task_group, which group_id depends on.
            parent_group.add(self)

        self.used_group_ids.add(self.group_id)
        if self.group_id:
            # Reserve the ids of the join nodes as well so nothing else can claim them.
            self.used_group_ids.add(self.downstream_join_id)
            self.used_group_ids.add(self.upstream_join_id)

        self.tooltip = tooltip
        self.ui_color = ui_color
        self.ui_fgcolor = ui_fgcolor

        # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately
        # so that we can optimize the number of edges when entire TaskGroups depend on each other.
        self.upstream_group_ids: set[str | None] = set()
        self.downstream_group_ids: set[str | None] = set()
        self.upstream_task_ids = set()
        self.downstream_task_ids = set()

    def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
        """
        Resolve a clash between ``self._group_id`` and the ids in ``used_group_ids``.

        :param add_suffix_on_collision: when True, rename this group to ``<base>__<n>`` where
            ``n`` is one more than the largest numeric suffix already used for ``<base>``;
            when False, a clash is an error.
        :raise DuplicateTaskIdFound: if the id is already used and suffixing is disabled.
        """
        if self._group_id is None:
            return
        if self._group_id not in self.used_group_ids:
            return
        if not add_suffix_on_collision:
            raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG")

        # if given group_id already used assign suffix by incrementing largest used suffix integer
        # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
        base = re.split(r"__\d+$", self._group_id)[0]
        # The base is escaped because a group id may contain regex metacharacters.
        suffix_pattern = re.compile(rf"^{re.escape(base)}__(\d+)$")
        largest_suffix = 0
        for used_group_id in self.used_group_ids:
            if used_group_id is None:
                continue
            matched = suffix_pattern.match(used_group_id)
            if matched:
                largest_suffix = max(largest_suffix, int(matched.group(1)))
        self._group_id = f"{base}__{largest_suffix + 1}"

    @classmethod
    def create_root(cls, dag: DAG) -> TaskGroup:
        """Create a root TaskGroup with no group_id or parent."""
        return cls(group_id=None, dag=dag)

    @property
    def node_id(self):
        """Node id of this TaskGroup, which is its group_id."""
        return self.group_id

    @property
    def is_root(self) -> bool:
        """Returns True if this TaskGroup is the root TaskGroup. Otherwise False"""
        return not self.group_id

    @property
    def parent_group(self) -> TaskGroup | None:
        """The TaskGroup this group was added to, as stored in ``task_group``."""
        return self.task_group

    def __iter__(self):
        """Yield every non-TaskGroup child, descending recursively into nested TaskGroups."""
        for child in self.children.values():
            if isinstance(child, TaskGroup):
                yield from child
            else:
                yield child

    def add(self, task: DAGNode) -> None:
        """Add a task to this TaskGroup.

        :param task: the operator or (empty) TaskGroup to register as a direct child.
        :raise TaskAlreadyInTaskGroup: if an operator already belongs to a different group.
        :raise DuplicateTaskIdFound: if a child with the same node id already exists.
        :raise RuntimeError: if a TaskGroup attached to another DAG is added.
        :raise AirflowException: if a TaskGroup that already has children is added.

        :meta private:
        """
        from airflow.models.abstractoperator import AbstractOperator

        existing_tg = task.task_group
        if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
            raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)

        # Set the TG first, as setting it might change the return value of node_id!
        task.task_group = weakref.proxy(self)
        key = task.node_id

        if key in self.children:
            node_type = "Task" if hasattr(task, "task_id") else "Task Group"
            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")

        if isinstance(task, TaskGroup):
            if self.dag:
                if task.dag is not None and self.dag is not task.dag:
                    # Interpolate the message instead of handing "%s" placeholders to the exception.
                    raise RuntimeError(
                        f"Cannot mix TaskGroups from different DAGs: {self.dag} and {task.dag}"
                    )
                task.dag = self.dag
            if task.children:
                raise AirflowException("Cannot add a non-empty TaskGroup")

        self.children[key] = task

    def _remove(self, task: DAGNode) -> None:
        """
        Remove a direct child and release its id from ``used_group_ids``.

        :raise KeyError: if the node is not a direct child of this group.
        """
        key = task.node_id

        if key not in self.children:
            raise KeyError(f"Node id {key!r} not part of this task group")

        self.used_group_ids.remove(key)
        del self.children[key]

    @property
    def group_id(self) -> str | None:
        """group_id of this TaskGroup."""
        # Prefixed with the parent's id only when the parent has an id and asks for prefixing.
        if self.task_group and self.task_group.prefix_group_id and self.task_group.group_id:
            return self.task_group.child_id(self._group_id)

        return self._group_id

    @property
    def label(self) -> str | None:
        """group_id excluding parent's group_id used as the node label in UI."""
        return self._group_id

    def update_relative(self, other: DependencyMixin, upstream=True) -> None:
        """
        Overrides TaskMixin.update_relative.

        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
        accordingly so that we can reduce the number of edges when displaying Graph view.

        :param other: the TaskGroup or task-like object the relationship is set with.
        :param upstream: True if ``other`` is being set upstream of this group.
        :raise AirflowException: if a root of ``other`` is not a DAGNode.
        """
        if isinstance(other, TaskGroup):
            # Handles setting relationship between a TaskGroup and another TaskGroup
            if upstream:
                parent, child = (self, other)
            else:
                parent, child = (other, self)

            parent.upstream_group_ids.add(child.group_id)
            child.downstream_group_ids.add(parent.group_id)
        else:
            # Handles setting relationship between a TaskGroup and a task
            for task in other.roots:
                if not isinstance(task, DAGNode):
                    raise AirflowException(
                        "Relationships can only be set between TaskGroup "
                        f"or operators; received {task.__class__.__name__}"
                    )

                if upstream:
                    self.upstream_task_ids.add(task.node_id)
                else:
                    self.downstream_task_ids.add(task.node_id)

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
        Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.

        :param task_or_task_list: a single dependency or a sequence of them.
        :param upstream: True to set the given items upstream of this group's roots,
            False to set them downstream of this group's leaves.
        :param edge_modifier: accepted, but not used by this implementation.
        """
        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        for task_like in task_or_task_list:
            self.update_relative(task_like, upstream)

        if upstream:
            for task in self.get_roots():
                task.set_upstream(task_or_task_list)
        else:
            for task in self.get_leaves():
                task.set_downstream(task_or_task_list)

    def __enter__(self) -> TaskGroup:
        """Make this TaskGroup the current context-managed TaskGroup."""
        TaskGroupContext.push_context_managed_task_group(self)
        return self

    def __exit__(self, _type, _value, _tb):
        """Restore the previously context-managed TaskGroup, if any."""
        TaskGroupContext.pop_context_managed_task_group()

    def has_task(self, task: BaseOperator) -> bool:
        """Returns True if this TaskGroup or its children TaskGroups contains the given task."""
        if task.task_id in self.children:
            return True

        return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup))

    @property
    def roots(self) -> list[BaseOperator]:
        """Required by TaskMixin"""
        return list(self.get_roots())

    @property
    def leaves(self) -> list[BaseOperator]:
        """Required by TaskMixin"""
        return list(self.get_leaves())

    def get_roots(self) -> Generator[BaseOperator, None, None]:
        """
        Returns a generator of tasks that are root tasks, i.e. those with no upstream
        dependencies within the TaskGroup.
        """
        for task in self:
            if not any(self.has_task(parent) for parent in task.get_direct_relatives(upstream=True)):
                yield task

    def get_leaves(self) -> Generator[BaseOperator, None, None]:
        """
        Returns a generator of tasks that are leaf tasks, i.e. those with no downstream
        dependencies within the TaskGroup
        """
        for task in self:
            if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)):
                yield task

    def child_id(self, label):
        """
        Prefix label with group_id if prefix_group_id is True. Otherwise return the label
        as-is.
        """
        if self.prefix_group_id and self.group_id:
            return f"{self.group_id}.{label}"

        return label

    @property
    def upstream_join_id(self) -> str:
        """
        Id of the dummy join node created in Graph view when this TaskGroup has immediate
        upstream TaskGroups or tasks, used to reduce the total number of edges displayed.
        """
        return f"{self.group_id}.upstream_join_id"

    @property
    def downstream_join_id(self) -> str:
        """
        Id of the dummy join node created in Graph view when this TaskGroup has immediate
        downstream TaskGroups or tasks, used to reduce the total number of edges displayed.
        """
        return f"{self.group_id}.downstream_join_id"

    def get_task_group_dict(self) -> dict[str, TaskGroup]:
        """Returns a flat dictionary of group_id: TaskGroup"""
        task_group_map = {}

        def build_map(task_group):
            # Plain tasks are skipped; only TaskGroups are recorded and descended into.
            if not isinstance(task_group, TaskGroup):
                return

            task_group_map[task_group.group_id] = task_group

            for child in task_group.children.values():
                build_map(child)

        build_map(self)
        return task_group_map

    def get_child_by_label(self, label: str) -> DAGNode:
        """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)"""
        return self.children[self.child_id(label)]

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Required by DAGNode."""
        from airflow.serialization.serialized_objects import TaskGroupSerialization

        return DagAttributeTypes.TASK_GROUP, TaskGroupSerialization.serialize_task_group(self)

    def topological_sort(self, _include_subdag_tasks: bool = False):
        """
        Sorts children in topographical order, such that a task comes after any of its
        upstream dependencies.

        :param _include_subdag_tasks: when True, the sorted tasks of a SubDagOperator's
            subdag are inserted right after that operator.
        :raise AirflowDagCycleException: if a full pass resolves no node.
        :return: list of tasks in topological order
        """
        # This uses a modified version of Kahn's Topological Sort algorithm to
        # not have to pre-compute the "in-degree" of the nodes.
        from airflow.operators.subdag import SubDagOperator  # Avoid circular import

        graph_unsorted = copy.copy(self.children)

        graph_sorted: list[DAGNode] = []

        # special case
        if len(self.children) == 0:
            return graph_sorted

        # Run until the unsorted graph is empty.
        while graph_unsorted:
            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
            # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating
            # over list(values()), a copy of the unsorted graph is used, allowing us to modify
            # the unsorted graph as we move through it.
            #
            # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
            # during each pass through the graph. If not, we need to exit as the graph therefore can't be
            # sorted.
            acyclic = False
            for node in list(graph_unsorted.values()):
                for edge in node.upstream_list:
                    if edge.node_id in graph_unsorted:
                        break
                    # Check for task's group is a child (or grand child) of this TG,
                    tg = edge.task_group
                    while tg:
                        if tg.node_id in graph_unsorted:
                            break
                        tg = tg.task_group

                    if tg:
                        # We are already going to visit that TG
                        break
                else:
                    # No upstream of this node is still pending, so it can be emitted now.
                    acyclic = True
                    del graph_unsorted[node.node_id]
                    graph_sorted.append(node)
                    if _include_subdag_tasks and isinstance(node, SubDagOperator):
                        graph_sorted.extend(
                            node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
                        )

            if not acyclic:
                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")

        return graph_sorted

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """Return mapped task groups in the hierarchy.

        Groups are returned from the closest to the outmost. If *self* is a
        mapped task group, it is returned first.

        :meta private:
        """
        group: TaskGroup | None = self
        while group is not None:
            if isinstance(group, MappedTaskGroup):
                yield group
            group = group.task_group

    def iter_tasks(self) -> Iterator[AbstractOperator]:
        """Returns an iterator of the child tasks.

        Groups are visited breadth-first, starting from this group.

        :raise ValueError: if a child is neither a TaskGroup nor an AbstractOperator.
        """
        from collections import deque

        from airflow.models.abstractoperator import AbstractOperator

        # A deque gives O(1) removal from the front of the queue.
        groups_to_visit = deque([self])

        while groups_to_visit:
            visiting = groups_to_visit.popleft()

            for child in visiting.children.values():
                if isinstance(child, AbstractOperator):
                    yield child
                elif isinstance(child, TaskGroup):
                    groups_to_visit.append(child)
                else:
                    raise ValueError(
                        f"Encountered a DAGNode that is not a TaskGroup or an AbstractOperator: {type(child)}"
                    )
class MappedTaskGroup(TaskGroup):
    """A mapped task group.

    Behaves like a regular task group; it only stores the expansion input as
    additional metadata to be used for expansion later.

    Don't instantiate this class directly; call *expand* or *expand_kwargs* on
    a ``@task_group`` function instead.
    """

    def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._expand_input = expand_input
        # Each operator referenced by the expansion input is set upstream of this group.
        for referenced_op, _ in expand_input.iter_references():
            self.set_upstream(referenced_op)

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this mapped task group."""
        from airflow.models.xcom_arg import XComArg

        references = XComArg.iter_xcom_references(self._expand_input)
        yield from (upstream_op for upstream_op, _ in references)

    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of instances a task in this group should be mapped to, when a DAG run is created.

        Only literal mapped arguments are considered; see ``:raise`` below for
        the non-literal case.

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If any non-literal mapped arguments are encountered.
        :return: The total number of mapped instances each task should have.
        """
        counts = (
            group._expand_input.get_parse_time_mapped_ti_count()
            for group in self.iter_mapped_task_groups()
        )
        return functools.reduce(operator.mul, counts)

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Number of instances a task in this group should be mapped to at run time.

        This considers both literal and non-literal mapped arguments, and the
        result is therefore available when all depended tasks have finished. The
        return value should be identical to ``parse_time_mapped_ti_count`` if
        all mapped arguments are literal.

        If this group is inside mapped task groups, all the nested counts are
        multiplied and accounted.

        :meta private:

        :raise NotFullyPopulated: If upstream tasks are not all complete yet.
        :return: Total number of mapped TIs this task should have.
        """
        lengths = (
            group._expand_input.get_total_map_length(run_id, session=session)
            for group in self.iter_mapped_task_groups()
        )
        return functools.reduce(operator.mul, lengths)
class TaskGroupContext:
    """Tracks the current TaskGroup while TaskGroups are used as context managers."""

    # The innermost TaskGroup currently entered, or None when no group is active.
    _context_managed_task_group: TaskGroup | None = None
    # TaskGroups that were current before the innermost one was entered, oldest first.
    _previous_context_managed_task_groups: list[TaskGroup] = []

    @classmethod
    def push_context_managed_task_group(cls, task_group: TaskGroup):
        """Make *task_group* the current TaskGroup, remembering the one it replaces."""
        current = cls._context_managed_task_group
        if current:
            cls._previous_context_managed_task_groups.append(current)
        cls._context_managed_task_group = task_group

    @classmethod
    def pop_context_managed_task_group(cls) -> TaskGroup | None:
        """Drop the current TaskGroup, restore the previous one (if any) and return the dropped one."""
        popped = cls._context_managed_task_group
        remembered = cls._previous_context_managed_task_groups
        cls._context_managed_task_group = remembered.pop() if remembered else None
        return popped

    @classmethod
    def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
        """Get the current TaskGroup."""
        from airflow.models.dag import DagContext

        active = cls._context_managed_task_group
        if active:
            return active

        dag = dag or DagContext.get_current_dag()
        if dag:
            # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag.
            return dag.task_group
        return active
def task_group_to_dict(task_item_or_group):
    """
    Build the nested dict describing a task or a TaskGroup (with its children) that is
    used to construct the Graph.

    An operator becomes a leaf ``{"id", "value"}`` dict. A TaskGroup becomes a dict that
    additionally carries ``"children"``: its child nodes sorted by label, followed by an
    upstream and/or downstream join node when the group has such relations recorded.
    """
    from airflow.models.abstractoperator import AbstractOperator

    node = task_item_or_group

    if isinstance(node, AbstractOperator):
        leaf_value = {
            "label": node.label,
            "labelStyle": f"fill:{node.ui_fgcolor};",
            "style": f"fill:{node.ui_color};",
            "rx": 5,
            "ry": 5,
        }
        return {"id": node.task_id, "value": leaf_value}

    def join_node(join_id):
        # This is the join node used to reduce the number of edges between two TaskGroup.
        return {
            "id": join_id,
            "value": {
                "label": "",
                "labelStyle": f"fill:{node.ui_fgcolor};",
                "style": f"fill:{node.ui_color};",
                "shape": "circle",
            },
        }

    is_mapped = isinstance(node, MappedTaskGroup)

    ordered_children = sorted(node.children.values(), key=operator.attrgetter("label"))
    children = [task_group_to_dict(child) for child in ordered_children]

    if node.upstream_group_ids or node.upstream_task_ids:
        children.append(join_node(node.upstream_join_id))

    if node.downstream_group_ids or node.downstream_task_ids:
        children.append(join_node(node.downstream_join_id))

    group_value = {
        "label": node.label,
        "labelStyle": f"fill:{node.ui_fgcolor};",
        "style": f"fill:{node.ui_color}",
        "rx": 5,
        "ry": 5,
        "clusterLabelPos": "top",
        "tooltip": node.tooltip,
        "isMapped": is_mapped,
    }
    return {"id": node.group_id, "value": group_value, "children": children}
665 }