1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import re
21from abc import ABCMeta, abstractmethod
22from collections.abc import Sequence
23from datetime import datetime
24from typing import TYPE_CHECKING, Any
25
26from airflow.sdk._shared.dagnode.node import GenericDAGNode
27from airflow.sdk.definitions._internal.mixins import DependencyMixin
28
29if TYPE_CHECKING:
30 from airflow.sdk.definitions.dag import DAG
31 from airflow.sdk.definitions.edges import EdgeModifier
32 from airflow.sdk.definitions.taskgroup import TaskGroup # noqa: F401
33 from airflow.sdk.types import Operator # noqa: F401
34 from airflow.serialization.enums import DagAttributeTypes
35
36
# Pattern for keys checked by ``validate_key``: one or more word characters, dots, or dashes.
# NOTE: ``$`` also matches just before a trailing newline, so ``.match()`` alone does not
# reject a value that ends in "\n"; ``.fullmatch()`` does.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Pattern for keys checked by ``validate_group_key``: same as above but dots are not allowed.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Matches each run of uppercase letters that does not start at the beginning of the string.
# Not used in this part of the file; presumably used for CamelCase -> snake_case conversion
# elsewhere (confirm against callers).
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")
40
41
def validate_key(k: str, max_length: int = 250):
    """
    Validate value used as a key.

    :param k: The candidate key.
    :param max_length: Maximum number of characters ``k`` may contain.
    :raises TypeError: If ``k`` is not a string.
    :raises ValueError: If ``k`` is longer than ``max_length``, is empty, or contains
        anything other than alphanumeric characters, dashes, dots, and underscores.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    if (length := len(k)) > max_length:
        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
    # ``fullmatch`` instead of ``match``: the pattern's ``$`` anchor also matches just
    # before a trailing newline, so ``match`` would accept a key such as "task\n".
    if not KEY_REGEX.fullmatch(k):
        raise ValueError(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            f"dots, and underscores exclusively"
        )
53
54
def validate_group_key(k: str, max_length: int = 200):
    """
    Validate value used as a group key.

    :param k: The candidate group key.
    :param max_length: Maximum number of characters ``k`` may contain.
    :raises TypeError: If ``k`` is not a string.
    :raises ValueError: If ``k`` is longer than ``max_length``, is empty, or contains
        anything other than alphanumeric characters, dashes, and underscores.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    if (length := len(k)) > max_length:
        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
    # ``fullmatch`` instead of ``match``: the pattern's ``$`` anchor also matches just
    # before a trailing newline, so ``match`` would accept a key such as "group\n".
    if not GROUP_KEY_REGEX.fullmatch(k):
        raise ValueError(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, and underscores exclusively"
        )
65
66
class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin, metaclass=ABCMeta):
    """
    Base class for a single node in the graph of a workflow.

    A node is either an Operator or a Task Group, and either kind may be mapped or unmapped.
    """

    start_date: datetime | None
    end_date: datetime | None

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Nodes that ``_set_relatives`` connects when this node is placed downstream of another."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Nodes that ``_set_relatives`` connects when this node is placed upstream of another."""
        raise NotImplementedError()

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Set relatives for the task or task list.

        :param task_or_task_list: One dependency object, or a sequence of them.
        :param upstream: If true, the given objects become upstream of this node;
            otherwise they become downstream of it.
        :param edge_modifier: Optional modifier that is told about every edge created.
        :raises TypeError: If a root/leaf of a given object is not an operator.
        :raises RuntimeError: If the tasks involved belong to more than one Dag.
        :raises ValueError: If none of the tasks involved has a Dag.
        """
        # Imported here rather than at module level (presumably to avoid an import cycle).
        from airflow.sdk.bases.operator import BaseOperator
        from airflow.sdk.definitions.mappedoperator import MappedOperator

        candidates = task_or_task_list if isinstance(task_or_task_list, Sequence) else [task_or_task_list]

        operators: list[BaseOperator | MappedOperator] = []
        for candidate in candidates:
            # Let the other side record the relationship first, seen from its own direction.
            candidate.update_relative(self, not upstream, edge_modifier=edge_modifier)
            endpoints = candidate.leaves if upstream else candidate.roots
            for endpoint in endpoints:
                if isinstance(endpoint, (BaseOperator, MappedOperator)):
                    operators.append(endpoint)
                    continue
                raise TypeError(
                    f"Relationships can only be set between Operators; received {endpoint.__class__.__name__}"
                )

        # All tasks involved must share one Dag; tasks that have none are attached to it below.
        dags: set[DAG] = {
            member.dag for member in [*self.roots, *operators] if member.has_dag() and member.dag
        }

        if len(dags) > 1:
            raise RuntimeError(f"Tried to set relationships between tasks in more than one Dag: {dags}")
        if len(dags) != 1:
            raise ValueError(
                "Tried to create relationships between tasks that don't have Dags yet. "
                f"Set the Dag for at least one task and try again: {[self, *operators]}"
            )
        dag = dags.pop()

        if not self.has_dag():
            # This node has no Dag yet, so it joins the Dag of the other tasks.
            self.dag = dag

        for operator in operators:
            if dag and not operator.has_dag():
                # The other task has no Dag yet, so it joins the same Dag as this node.
                dag.add_task(operator)  # type: ignore[arg-type]
            if not upstream:
                # Edge direction: self -> operator.
                self.downstream_task_ids.add(operator.node_id)
                operator.upstream_task_ids.add(self.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(dag, self.node_id, operator.node_id)
            else:
                # Edge direction: operator -> self.
                operator.downstream_task_ids.add(self.node_id)
                self.upstream_task_ids.add(operator.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(dag, operator.node_id, self.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make the given node (or nodes) run directly downstream of this node."""
        self._set_relatives(
            task_or_task_list,
            upstream=False,
            edge_modifier=edge_modifier,
        )

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make the given node (or nodes) run directly upstream of this node."""
        self._set_relatives(
            task_or_task_list,
            upstream=True,
            edge_modifier=edge_modifier,
        )

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """
        Serialize a task group's content; used by TaskGroupSerialization.

        This base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError()