Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/taskmixin.py: 61%

136 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import warnings 

20from abc import ABCMeta, abstractmethod 

21from typing import TYPE_CHECKING, Any, Iterable, Sequence 

22 

23import pendulum 

24 

25from airflow.exceptions import AirflowException, RemovedInAirflow3Warning 

26from airflow.serialization.enums import DagAttributeTypes 

27 

28if TYPE_CHECKING: 

29 from logging import Logger 

30 

31 from airflow.models.dag import DAG 

32 from airflow.models.operator import Operator 

33 from airflow.utils.edgemodifier import EdgeModifier 

34 from airflow.utils.task_group import TaskGroup 

35 

36 

class DependencyMixin:
    """Mixin implementing the common dependency-setting operators ``>>`` and ``<<``."""

    @property
    def roots(self) -> Sequence[DependencyMixin]:
        """
        Root nodes of this sub-graph -- the ones with no upstream dependencies.

        In other words, the "start" of the sub-graph. The base implementation
        only raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @property
    def leaves(self) -> Sequence[DependencyMixin]:
        """
        Leaf nodes of this sub-graph -- the ones with only upstream dependencies.

        In other words, the "end" of the sub-graph. The base implementation
        only raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @abstractmethod
    def set_upstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """Make a task, or every task in a list, directly upstream of the current task."""
        raise NotImplementedError

    @abstractmethod
    def set_downstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """Make a task, or every task in a list, directly downstream of the current task."""
        raise NotImplementedError

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Record relationship information about another node.

        The default implementation does nothing; override it where bookkeeping is needed.
        """
        return None

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``self << other``: ``other`` becomes upstream of ``self``."""
        self.set_upstream(other)
        # Hand back the right-hand operand so that chained expressions keep flowing.
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``self >> other``: ``other`` becomes downstream of ``self``."""
        self.set_downstream(other)
        # Hand back the right-hand operand so that chained expressions keep flowing.
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``other >> self`` for operands (e.g. lists) that lack ``__rshift__``."""
        # ``other >> self`` is the same relationship as ``self << other``.
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``other << self`` for operands (e.g. lists) that lack ``__lshift__``."""
        # ``other << self`` is the same relationship as ``self >> other``.
        self.__rshift__(other)
        return self

99 

100 

class TaskMixin(DependencyMixin):
    """Deprecated name for ``DependencyMixin``; subclassing it emits a warning.

    :meta private:
    """

    def __init_subclass__(cls) -> None:
        """Warn that the class being defined should derive from ``DependencyMixin`` instead."""
        message = f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}"
        # stacklevel=2 attributes the warning to the code defining the subclass.
        warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
        return super().__init_subclass__()

114 

115 

class DAGNode(DependencyMixin, metaclass=ABCMeta):
    """
    Base class for a node in a workflow graph.

    A node is an Operator or a Task Group, and either of them may be mapped or unmapped.
    """

    dag: DAG | None = None
    task_group: TaskGroup | None = None
    """The task_group that contains this node"""

    @property
    @abstractmethod
    def node_id(self) -> str:
        """Identifier of this node; concrete subclasses must provide it."""
        raise NotImplementedError

    @property
    def label(self) -> str | None:
        """
        Node id with the enclosing task group's prefix removed, when one applies.

        The prefix is stripped only if the node has a task group whose ``node_id``
        and ``prefix_group_id`` are both truthy; otherwise the full ``node_id`` is returned.
        """
        group = self.task_group
        if not (group and group.node_id and group.prefix_group_id):
            return self.node_id
        # "task_group_id.task_id" -> "task_id": skip the group id plus the separator.
        prefix_length = len(group.node_id) + 1
        return self.node_id[prefix_length:]

    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    def has_dag(self) -> bool:
        """Return whether this node currently has a DAG assigned."""
        return self.dag is not None

    @property
    def dag_id(self) -> str:
        """Return the id of the assigned DAG, or a placeholder id when there is none."""
        return self.dag.dag_id if self.dag else "_in_memory_dag_"

    @property
    def log(self) -> Logger:
        """Logger for this node; the base implementation only raises ``NotImplementedError``."""
        raise NotImplementedError

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Nodes with no upstream dependencies; concrete subclasses must provide them."""
        raise NotImplementedError

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Nodes with only upstream dependencies; concrete subclasses must provide them."""
        raise NotImplementedError

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Wire this node to one or more other nodes.

        :param task_or_task_list: a single node, or a sequence of nodes, to relate to this one
        :param upstream: when true the given nodes become upstream of this node,
            otherwise they become downstream of it
        :param edge_modifier: optional object that receives the info of every edge created
        :raises AirflowException: if a resolved relative is not an operator, if the nodes
            involved span more than one DAG, or if none of them has a DAG
        """
        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.operator import Operator

        # Normalise the argument so a lone node is handled like a one-element sequence.
        if isinstance(task_or_task_list, Sequence):
            candidates = task_or_task_list
        else:
            candidates = [task_or_task_list]

        task_list: list[Operator] = []
        for candidate in candidates:
            # Let the other side record the relationship first, seen from its own direction.
            candidate.update_relative(self, not upstream, edge_modifier=edge_modifier)
            # Connect to the candidate's leaves when it sits upstream, to its roots otherwise.
            endpoints = candidate.leaves if upstream else candidate.roots
            for task in endpoints:
                if not isinstance(task, (BaseOperator, MappedOperator)):
                    raise AirflowException(
                        f"Relationships can only be set between Operators; received {task.__class__.__name__}"
                    )
                task_list.append(task)

        # Relationships can only be set if the tasks share a single DAG. Tasks
        # without a DAG are assigned to that DAG.
        dags: set[DAG] = set()
        for node in [*self.roots, *task_list]:
            if node.has_dag() and node.dag:
                dags.add(node.dag)

        if not dags:
            raise AirflowException(
                f"Tried to create relationships between tasks that don't have DAGs yet. "
                f"Set the DAG for at least one task and try again: {[self, *task_list]}"
            )
        if len(dags) > 1:
            raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
        dag = dags.pop()

        if not self.has_dag():
            # This node has no DAG yet, so it joins the one shared by the other nodes.
            self.dag = dag

        for task in task_list:
            if dag and not task.has_dag():
                # The other task has no DAG yet, so add it to the shared one.
                dag.add_task(task)

            # Orient the edge: it always runs from ``source`` to ``target``.
            if upstream:
                source, target = task, self
            else:
                source, target = self, task

            source.downstream_task_ids.add(target.node_id)
            target.upstream_task_ids.add(source.node_id)
            if edge_modifier:
                edge_modifier.add_edge_info(self.dag, source.node_id, target.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make a node, or every node in a sequence, directly downstream of this node."""
        self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Make a node, or every node in a sequence, directly upstream of this node."""
        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)

    @property
    def downstream_list(self) -> Iterable[Operator]:
        """
        Nodes directly downstream, looked up by id in the assigned DAG.

        :raises AirflowException: if no DAG has been assigned
        """
        dag = self.dag
        if not dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [dag.get_task(task_id) for task_id in self.downstream_task_ids]

    @property
    def upstream_list(self) -> Iterable[Operator]:
        """
        Nodes directly upstream, looked up by id in the assigned DAG.

        :raises AirflowException: if no DAG has been assigned
        """
        dag = self.dag
        if not dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [dag.get_task(task_id) for task_id in self.upstream_task_ids]

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """
        Return the ids of the direct relatives of this node.

        :param upstream: return upstream ids when true, downstream ids otherwise
        """
        return self.upstream_task_ids if upstream else self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
        """
        Return the direct relatives of this node.

        :param upstream: return upstream nodes when true, downstream nodes otherwise
        """
        return self.upstream_list if upstream else self.downstream_list

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize this node as part of a task group's content; used by TaskGroupSerialization."""
        raise NotImplementedError