Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/taskmixin.py: 56%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

153 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import warnings 

20from abc import ABCMeta, abstractmethod 

21from typing import TYPE_CHECKING, Any, Iterable, Sequence 

22 

23from airflow.exceptions import AirflowException, RemovedInAirflow3Warning 

24from airflow.utils.types import NOTSET 

25 

26if TYPE_CHECKING: 

27 from logging import Logger 

28 

29 import pendulum 

30 

31 from airflow.models.baseoperator import BaseOperator 

32 from airflow.models.dag import DAG 

33 from airflow.models.operator import Operator 

34 from airflow.serialization.enums import DagAttributeTypes 

35 from airflow.utils.edgemodifier import EdgeModifier 

36 from airflow.utils.task_group import TaskGroup 

37 from airflow.utils.types import ArgNotSet 

38 

39 

class DependencyMixin:
    """Mixin implementing common dependency-setting methods like ``>>`` and ``<<``."""

    @property
    def roots(self) -> Sequence[DependencyMixin]:
        """
        List of root nodes -- ones with no upstream dependencies.

        a.k.a. the "start" of this sub-graph
        """
        raise NotImplementedError()

    @property
    def leaves(self) -> Sequence[DependencyMixin]:
        """
        List of leaf nodes -- ones with only upstream dependencies.

        a.k.a. the "end" of this sub-graph
        """
        raise NotImplementedError()

    # NOTE: this class does not use ABCMeta, so @abstractmethod below is not
    # enforced at instantiation time; calling the base versions simply raises
    # NotImplementedError.
    @abstractmethod
    def set_upstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """Set a task or a task list to be directly upstream from the current task."""
        raise NotImplementedError()

    @abstractmethod
    def set_downstream(
        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
    ):
        """Set a task or a task list to be directly downstream from the current task."""
        raise NotImplementedError()

    def as_setup(self) -> DependencyMixin:
        """Mark a task as setup task."""
        raise NotImplementedError()

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
        on_failure_fail_dagrun=NOTSET,
    ) -> DependencyMixin:
        """Mark a task as teardown and set its setups as direct relatives."""
        raise NotImplementedError()

    def update_relative(
        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
    ) -> None:
        """
        Update relationship information about another DependencyMixin. Default is no-op.

        Override if necessary.
        """

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``Task << Task``: make *other* upstream of self and return *other*."""
        self.set_upstream(other)
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implement ``Task >> Task``: make *other* downstream of self and return *other*."""
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``[Task] >> Task``.

        Python falls back to this reflected operator because lists have no
        ``__rshift__``. ``other >> self`` means *other* is upstream of self.
        """
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """
        Implement ``[Task] << Task``.

        Python falls back to this reflected operator because lists have no
        ``__lshift__``. ``other << self`` means *other* is downstream of self.
        """
        self.__rshift__(other)
        return self

    @classmethod
    def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]:
        """
        Recursively yield ``(reference, key)`` pairs found in *obj*.

        Operators yield themselves with the key ``"operator"``.
        ``ResolveMixin`` objects delegate to their own ``iter_references()``.
        Other sequences are walked element by element. Anything else, including
        strings and bytes, yields nothing.
        """
        from airflow.models.baseoperator import AbstractOperator
        from airflow.utils.mixins import ResolveMixin

        if isinstance(obj, AbstractOperator):
            yield obj, "operator"
        elif isinstance(obj, ResolveMixin):
            yield from obj.iter_references()
        elif isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
            # A str is a Sequence whose items are one-character strs, so
            # recursing into it never terminates (RecursionError). bytes would
            # only yield ints and contribute nothing. Treat both as leaf values.
            for o in obj:
                yield from cls._iter_references(o)

130 

class TaskMixin(DependencyMixin):
    """Mixin to provide task-related things.

    :meta private:
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Warn that ``TaskMixin`` is deprecated whenever it is subclassed.

        ``TaskMixin`` has been renamed to ``DependencyMixin``. Class keyword
        arguments are forwarded to ``super()``, following the cooperative
        pattern from PEP 487. This keeps ``TaskMixin`` compatible with other
        bases that define their own ``__init_subclass__``.
        """
        warnings.warn(
            f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}",
            category=RemovedInAirflow3Warning,
            # __init_subclass__ is invoked from class creation, so level 2
            # points at the subclass's ``class`` statement.
            stacklevel=2,
        )
        super().__init_subclass__(**kwargs)

144 

145 

class DAGNode(DependencyMixin, metaclass=ABCMeta):
    """
    A base class for a node in the graph of a workflow.

    A node may be an Operator or a Task Group, either mapped or unmapped.
    """

    dag: DAG | None = None
    task_group: TaskGroup | None = None
    """The task_group that contains this node"""

    @property
    @abstractmethod
    def node_id(self) -> str:
        """Identifier of this node; must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """
        Return ``node_id`` relative to the containing task group.

        If the node has a task group with a ``node_id`` and ``prefix_group_id``
        enabled, the leading ``"<group node_id>."`` is sliced off. Otherwise the
        full ``node_id`` is returned. The slice assumes ``node_id`` actually
        starts with that prefix.
        """
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(tg.node_id) + 1 :]
        return self.node_id

    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    def has_dag(self) -> bool:
        """Return True if this node has been assigned to a DAG."""
        return self.dag is not None

    @property
    def dag_id(self) -> str:
        """Returns dag id if it has one or an adhoc/meaningless ID."""
        if self.dag:
            return self.dag.dag_id
        return "_in_memory_dag_"

    @property
    def log(self) -> Logger:
        """Logger for this node; must be provided by subclasses."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Nodes of this sub-graph that have no upstream dependencies."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Nodes of this sub-graph that have no downstream dependencies."""
        raise NotImplementedError()

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Set relatives for the task or task list.

        :param task_or_task_list: node(s) to connect to this node.
        :param upstream: if True, the given nodes become upstream of this node;
            otherwise they become downstream of it.
        :param edge_modifier: optional modifier; its edge info is recorded on the
            DAG for every edge created.
        :raises AirflowException: if a resolved relative is not an operator, if
            the involved tasks belong to more than one DAG, or if none of them
            has a DAG yet.
        """
        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator

        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        task_list: list[Operator] = []
        for task_object in task_or_task_list:
            # Give the other object a chance to record the relationship
            # (DependencyMixin.update_relative is a no-op by default).
            task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
            # Upstream relatives connect through their leaves, downstream ones
            # through their roots.
            relatives = task_object.leaves if upstream else task_object.roots
            for task in relatives:
                if not isinstance(task, (BaseOperator, MappedOperator)):
                    raise AirflowException(
                        f"Relationships can only be set between Operators; received {task.__class__.__name__}"
                    )
                task_list.append(task)

        # relationships can only be set if the tasks share a single DAG. Tasks
        # without a DAG are assigned to that DAG.
        dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}

        if len(dags) > 1:
            raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
        if not dags:
            raise AirflowException(
                "Tried to create relationships between tasks that don't have DAGs yet. "
                f"Set the DAG for at least one task and try again: {[self, *task_list]}"
            )
        # Exactly one DAG remains; it was filtered for truthiness above.
        dag = dags.pop()

        if not self.has_dag():
            # If this task does not yet have a dag, add it to the same dag as the other task.
            self.dag = dag

        for task in task_list:
            if not task.has_dag():
                # If the other task does not yet have a dag, add it to the same dag as this task.
                dag.add_task(task)
            if upstream:
                task.downstream_task_ids.add(self.node_id)
                self.upstream_task_ids.add(task.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, task.node_id, self.node_id)
            else:
                self.downstream_task_ids.add(task.node_id)
                task.upstream_task_ids.add(self.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.node_id, task.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly downstream from the current node."""
        self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly upstream from the current node."""
        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)

    @property
    def downstream_list(self) -> Iterable[Operator]:
        """
        List of nodes directly downstream.

        :raises AirflowException: if this node has not been assigned to a DAG.
        """
        if not self.dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [self.dag.get_task(tid) for tid in self.downstream_task_ids]

    @property
    def upstream_list(self) -> Iterable[Operator]:
        """
        List of nodes directly upstream.

        :raises AirflowException: if this node has not been assigned to a DAG.
        """
        if not self.dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [self.dag.get_task(tid) for tid in self.upstream_task_ids]

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """
        Get set of the direct relative ids to the current task, upstream or downstream.

        The node's own id set is returned, not a copy.
        """
        if upstream:
            return self.upstream_task_ids
        else:
            return self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
        """Get list of the direct relatives to the current task, upstream or downstream."""
        if upstream:
            return self.upstream_list
        else:
            return self.downstream_list

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Serialize a task group's content; used by TaskGroupSerialization."""
        raise NotImplementedError()