Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/taskmixin.py: 61%

140 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import warnings 

20from abc import ABCMeta, abstractmethod 

21from typing import TYPE_CHECKING, Any, Iterable, Sequence 

22 

23import pendulum 

24 

25from airflow.exceptions import AirflowException, RemovedInAirflow3Warning 

26from airflow.serialization.enums import DagAttributeTypes 

27 

28if TYPE_CHECKING: 

29 from logging import Logger 

30 

31 from airflow.models.dag import DAG 

32 from airflow.models.operator import Operator 

33 from airflow.utils.edgemodifier import EdgeModifier 

34 from airflow.utils.task_group import TaskGroup 

35 

36 

class DependencyMixin:
    """Mixin implementing common dependency setting methods like >> and <<."""

    @property
    def roots(self) -> Sequence[DependencyMixin]:
        """
        List of root nodes -- ones with no upstream dependencies.

        a.k.a. the "start" of this sub-graph
        """
        raise NotImplementedError()

    @property
    def leaves(self) -> Sequence[DependencyMixin]:
        """
        List of leaf nodes -- ones with only upstream dependencies.

        a.k.a. the "end" of this sub-graph
        """
        raise NotImplementedError()

    @abstractmethod
    def set_upstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Set a task or a task list to be directly upstream from the current task."""
        raise NotImplementedError()

    @abstractmethod
    def set_downstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Set a task or a task list to be directly downstream from the current task."""
        raise NotImplementedError()

    def update_relative(self, other: DependencyMixin, upstream: bool = True) -> None:
        """
        Update relationship information about another DependencyMixin. Default is no-op.

        Override if necessary.
        """

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implements ``Task << Task``; returns ``other`` so that expressions can be chained."""
        self.set_upstream(other)
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implements ``Task >> Task``; returns ``other`` so that expressions can be chained."""
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Called for ``[Task] >> Task`` because lists don't have ``__rshift__`` operators."""
        # ``other`` is the left-hand operand, so it becomes upstream of ``self``.
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Called for ``[Task] << Task`` because lists don't have ``__lshift__`` operators."""
        # ``other`` is the left-hand operand, so it becomes downstream of ``self``.
        self.__rshift__(other)
        return self

93 

94 

class TaskMixin(DependencyMixin):
    """
    Deprecated name for DependencyMixin; subclassing it emits RemovedInAirflow3Warning.

    :meta private:
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Warn that ``TaskMixin`` was renamed whenever a subclass is created.

        Class keyword arguments are forwarded to the next ``__init_subclass__``
        in the MRO so that cooperative base classes keep working.
        """
        warnings.warn(
            f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}",
            category=RemovedInAirflow3Warning,
            # stacklevel=2 attributes the warning to the ``class`` statement of the subclass.
            stacklevel=2,
        )
        super().__init_subclass__(**kwargs)

105 

106 

class DAGNode(DependencyMixin, metaclass=ABCMeta):
    """
    A base class for a node in the graph of a workflow -- an Operator or a Task Group, either mapped or
    unmapped.
    """

    dag: DAG | None = None
    task_group: TaskGroup | None = None
    """The task_group that contains this node"""

    @property
    @abstractmethod
    def node_id(self) -> str:
        """Identifier of this node; concrete subclasses must implement it."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """
        Return ``node_id`` without the leading task group id.

        The prefix is stripped only when the node belongs to a task group that
        has a ``node_id`` and has ``prefix_group_id`` set; otherwise the full
        ``node_id`` is returned.
        """
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(tg.node_id) + 1 :]
        return self.node_id

    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    def has_dag(self) -> bool:
        """Return True if this node has been assigned to a DAG."""
        return self.dag is not None

    @property
    def dag_id(self) -> str:
        """Returns dag id if it has one or an adhoc/meaningless ID"""
        if self.dag:
            return self.dag.dag_id
        return "_in_memory_dag_"

    @property
    def log(self) -> Logger:
        """Logger used by this node; not implemented on the base class."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Root nodes of this node's sub-graph; concrete subclasses must implement it."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Leaf nodes of this node's sub-graph; concrete subclasses must implement it."""
        raise NotImplementedError()

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Sets relatives for the task or task list.

        :param task_or_task_list: a single node or a sequence of nodes to relate to this one.
        :param upstream: if True the given nodes become upstream of this node,
            otherwise they become downstream.
        :param edge_modifier: when given, ``add_edge_info`` is called on it for every edge created.
        :raises AirflowException: if a resolved relative is not an operator, if the
            involved tasks belong to more than one DAG, or if none of them has a DAG.
        """
        # Imported locally rather than at module level.
        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.operator import Operator

        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        task_list: list[Operator] = []
        for task_object in task_or_task_list:
            # Let the other side record the relationship from its own point of view.
            task_object.update_relative(self, not upstream)
            # An upstream object connects through its leaves, a downstream one through its roots.
            relatives = task_object.leaves if upstream else task_object.roots
            for task in relatives:
                if not isinstance(task, (BaseOperator, MappedOperator)):
                    raise AirflowException(
                        f"Relationships can only be set between Operators; received {task.__class__.__name__}"
                    )
                task_list.append(task)

        # relationships can only be set if the tasks share a single DAG. Tasks
        # without a DAG are assigned to that DAG.
        dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}

        if len(dags) > 1:
            raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
        elif len(dags) == 1:
            dag = dags.pop()
        else:
            raise AirflowException(
                "Tried to create relationships between tasks that don't have DAGs yet. "
                f"Set the DAG for at least one task and try again: {[self, *task_list]}"
            )

        if not self.has_dag():
            # If this task does not yet have a dag, add it to the same dag as the other task.
            self.dag = dag

        def add_only_new(obj, item_set: set[str], item: str) -> None:
            """Adds only new items to item set"""
            if item in item_set:
                self.log.warning("Dependency %s, %s already registered for DAG: %s", obj, item, dag.dag_id)
            else:
                item_set.add(item)

        for task in task_list:
            if dag and not task.has_dag():
                # If the other task does not yet have a dag, add it to the same dag as this task.
                dag.add_task(task)
            if upstream:
                add_only_new(task, task.downstream_task_ids, self.node_id)
                add_only_new(self, self.upstream_task_ids, task.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, task.node_id, self.node_id)
            else:
                add_only_new(self, self.downstream_task_ids, task.node_id)
                add_only_new(task, task.upstream_task_ids, self.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.node_id, task.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly downstream from the current node."""
        self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly upstream from the current node."""
        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)

    @property
    def downstream_list(self) -> Iterable[Operator]:
        """
        List of nodes directly downstream

        :raises AirflowException: if this node has not been assigned to a DAG.
        """
        if not self.dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [self.dag.get_task(tid) for tid in self.downstream_task_ids]

    @property
    def upstream_list(self) -> Iterable[Operator]:
        """
        List of nodes directly upstream

        :raises AirflowException: if this node has not been assigned to a DAG.
        """
        if not self.dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [self.dag.get_task(tid) for tid in self.upstream_task_ids]

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """
        Get set of the direct relative ids to the current task, upstream or
        downstream.
        """
        if upstream:
            return self.upstream_task_ids
        else:
            return self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
        """
        Get list of the direct relatives to the current task, upstream or
        downstream.
        """
        if upstream:
            return self.upstream_list
        else:
            return self.downstream_list

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """This is used by TaskGroupSerialization to serialize a task group's content."""
        raise NotImplementedError()