Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/node.py: 33%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

82 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import re 

21from abc import ABCMeta, abstractmethod 

22from collections.abc import Sequence 

23from datetime import datetime 

24from typing import TYPE_CHECKING, Any 

25 

26from airflow.sdk._shared.dagnode.node import GenericDAGNode 

27from airflow.sdk.definitions._internal.mixins import DependencyMixin 

28 

29if TYPE_CHECKING: 

30 from airflow.sdk.definitions.dag import DAG 

31 from airflow.sdk.definitions.edges import EdgeModifier 

32 from airflow.sdk.definitions.taskgroup import TaskGroup # noqa: F401 

33 from airflow.sdk.types import Operator # noqa: F401 

34 from airflow.serialization.enums import DagAttributeTypes 

35 

36 

# One or more word characters (``\w``), dots, or dashes.
KEY_REGEX = re.compile(r"^[\w.-]+$")
# Same as ``KEY_REGEX`` but without the dot: word characters and dashes only.
GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
# Captures each run of uppercase letters that does not start the string.
# NOTE(review): presumably used for CamelCase -> snake_case conversion; the
# consumer is not visible here -- confirm.
CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")


def validate_key(k: str, max_length: int = 250) -> None:
    """
    Validate value used as a key.

    :param k: Candidate key.
    :param max_length: Maximum number of characters allowed in ``k``. A key of
        exactly this length is accepted, even though the error message says
        "less than".
    :raises TypeError: If ``k`` is not a string.
    :raises ValueError: If ``k`` is longer than ``max_length``, or is not made
        up entirely of characters permitted by ``KEY_REGEX``.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    if (length := len(k)) > max_length:
        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
    # ``fullmatch`` instead of ``match``: the pattern's trailing ``$`` also
    # matches just before a final newline, so ``match`` would accept "key\n".
    if not KEY_REGEX.fullmatch(k):
        raise ValueError(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, "
            "dots, and underscores exclusively"
        )

53 

54 

def validate_group_key(k: str, max_length: int = 200) -> None:
    """
    Validate value used as a group key.

    Unlike ``validate_key``, dots are not permitted and the default length
    limit is 200 characters.

    :param k: Candidate group key.
    :param max_length: Maximum number of characters allowed in ``k``. A key of
        exactly this length is accepted, even though the error message says
        "less than".
    :raises TypeError: If ``k`` is not a string.
    :raises ValueError: If ``k`` is longer than ``max_length``, or is not made
        up entirely of characters permitted by ``GROUP_KEY_REGEX``.
    """
    if not isinstance(k, str):
        raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
    if (length := len(k)) > max_length:
        raise ValueError(f"The key has to be less than {max_length} characters, not {length}")
    # ``fullmatch`` instead of ``match``: the pattern's trailing ``$`` also
    # matches just before a final newline, so ``match`` would accept "group\n".
    if not GROUP_KEY_REGEX.fullmatch(k):
        raise ValueError(
            f"The key {k!r} has to be made of alphanumeric characters, dashes, and underscores exclusively"
        )

65 

66 

class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin, metaclass=ABCMeta):
    """
    Abstract base for anything that can sit in a workflow graph.

    Concrete nodes are Operators or Task Groups, mapped or unmapped.
    """

    start_date: datetime | None
    end_date: datetime | None

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Nodes that get wired up when this object is set downstream of another node."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Nodes that get wired up when this object is set upstream of another node."""
        raise NotImplementedError()

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Wire this node to one or more other nodes.

        :param task_or_task_list: A single dependency-capable object or a sequence of them.
        :param upstream: If true, the given objects become upstream of this node;
            otherwise they become downstream of it.
        :param edge_modifier: Optional modifier that is notified of every edge created.
        :raises TypeError: If a resolved root/leaf is neither a ``BaseOperator`` nor a
            ``MappedOperator``.
        :raises RuntimeError: If the tasks involved belong to more than one Dag.
        :raises ValueError: If none of the tasks involved has a Dag yet.
        """
        from airflow.sdk.bases.operator import BaseOperator
        from airflow.sdk.definitions.mappedoperator import MappedOperator

        candidates = task_or_task_list if isinstance(task_or_task_list, Sequence) else [task_or_task_list]

        operators: list[BaseOperator | MappedOperator] = []
        for candidate in candidates:
            # Let the other side record the relationship first, seen from its own direction.
            candidate.update_relative(self, not upstream, edge_modifier=edge_modifier)
            for node in candidate.leaves if upstream else candidate.roots:
                if not isinstance(node, (BaseOperator, MappedOperator)):
                    raise TypeError(
                        f"Relationships can only be set between Operators; received {node.__class__.__name__}"
                    )
                operators.append(node)

        # All tasks taking part must share one Dag; tasks that have none yet are
        # attached to that Dag further down.
        dags: set[DAG] = set()
        for node in [*self.roots, *operators]:
            if node.has_dag() and node.dag:
                dags.add(node.dag)

        if len(dags) > 1:
            raise RuntimeError(f"Tried to set relationships between tasks in more than one Dag: {dags}")
        if len(dags) != 1:
            raise ValueError(
                "Tried to create relationships between tasks that don't have Dags yet. "
                f"Set the Dag for at least one task and try again: {[self, *operators]}"
            )
        dag = dags.pop()

        if not self.has_dag():
            # This node has no Dag yet: adopt the one the other tasks already use.
            self.dag = dag

        for node in operators:
            if dag and not node.has_dag():
                # The other task has no Dag yet: add it to the same Dag as this node.
                dag.add_task(node)  # type: ignore[arg-type]

            # Orient the edge once, then record it on both endpoints.
            if upstream:
                source, target = node, self
            else:
                source, target = self, node
            source.downstream_task_ids.add(target.node_id)
            target.upstream_task_ids.add(source.node_id)
            if edge_modifier:
                edge_modifier.add_edge_info(dag, source.node_id, target.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make the given node (or nodes) run directly after the current node."""
        self._set_relatives(
            task_or_task_list,
            upstream=False,
            edge_modifier=edge_modifier,
        )

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make the given node (or nodes) run directly before the current node."""
        self._set_relatives(
            task_or_task_list,
            upstream=True,
            edge_modifier=edge_modifier,
        )

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """
        Serialize a task group's content; used by TaskGroupSerialization.

        Not implemented on the base class; always raises ``NotImplementedError`` here.
        """
        raise NotImplementedError()