#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import datetime
import logging
from abc import abstractmethod
from collections.abc import (
    Callable,
    Collection,
    Iterable,
    Iterator,
)
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias

from airflow.sdk import TriggerRule, WeightRule
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.mixins import DependencyMixin
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
from airflow.sdk.definitions._internal.templater import Templater
from airflow.sdk.definitions.context import Context

if TYPE_CHECKING:
    import jinja2

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.bases.operatorlink import BaseOperatorLink
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.definitions.mappedoperator import MappedOperator
    from airflow.sdk.definitions.taskgroup import MappedTaskGroup


TaskStateChangeCallback = Callable[[Context], None]
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)

log = logging.getLogger(__name__)


class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    owner: str
    task_id: str

    outlets: list
    inlets: list

    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False
    is_setup: bool = False
    is_teardown: bool = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False


    @property
    def dag_id(self) -> str:
        """Return the DAG ID if the task is attached to a DAG, otherwise an ad-hoc ID based on the owner."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"


    @property
    def node_id(self) -> str:
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def is_mapped(self):
        return self._is_mapped


    @property
    def label(self) -> str | None:
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # If no display name is given, the prefix handling is cloned from taskmixin for compatibility
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id

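    # Illustrative sketch (not from the original source) of the prefix stripping above,
    # assuming a hypothetical task inside a task group with the default ``prefix_group_id=True``:
    #
    #   with TaskGroup(group_id="extract") as tg:
    #       t = EmptyOperator(task_id="fetch")  # task_id becomes "extract.fetch"
    #   assert t.label == "fetch"  # the "extract." prefix is stripped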

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value


    @property
    def inherits_from_empty_operator(self):
        """Used to determine if an Operator is inherited from EmptyOperator."""
        # This looks like `isinstance(self, EmptyOperator)` would work, but this also
        # needs to cope when `self` is a Serialized instance of an EmptyOperator or one
        # of its subclasses (which don't inherit from anything but BaseOperator).
        return getattr(self, "_is_empty", False)

    @property
    def inherits_from_skipmixin(self):
        """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin)."""
        return getattr(self, "_can_skip_downstream", False)


    def as_setup(self):
        self.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | None = None,
        on_failure_fail_dagrun: bool | None = None,
    ):
        self.is_teardown = True
        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
        if on_failure_fail_dagrun is not None:
            self.on_failure_fail_dagrun = on_failure_fail_dagrun
        if setups is not None:
            setups = [setups] if isinstance(setups, DependencyMixin) else setups
            for s in setups:
                s.is_setup = True
                s >> self
        return self

    def __enter__(self):
        if not self.is_setup and not self.is_teardown:
            raise RuntimeError("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()

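    # Illustrative usage sketch (not from the original source): ``as_teardown(setups=...)``
    # marks the task as a teardown, wires each setup upstream of it, and the context
    # manager scopes the "work" tasks between setup and teardown.
    #
    #   create = CreateResourceOperator(task_id="create")  # hypothetical operators
    #   delete = DeleteResourceOperator(task_id="delete")
    #   with delete.as_teardown(setups=create):
    #       use = UseResourceOperator(task_id="use")  # effectively create >> use >> delete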

    # TODO: Task-SDK -- Should the following methods be removed?
    #  get_template_env
    #  _render
    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)


    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                if callable(value):
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)

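    # Illustrative sketch (not from the original source) of what this method consumes: an
    # operator lists attribute names in ``template_fields``; string values are rendered
    # with Jinja, while callables are invoked with the context and Jinja environment.
    #
    #   class MyOperator(BaseOperator):  # hypothetical operator
    #       template_fields = ("sql",)
    #
    #       def __init__(self, sql, **kwargs):
    #           super().__init__(**kwargs)
    #           self.sql = sql  # e.g. "SELECT * FROM t WHERE ds = '{{ ds }}'"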

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping; it only checks that those tasks are
        mapped operators and are downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.sdk.definitions.mappedoperator import MappedOperator
        from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """
            Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child


    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )

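    # Illustrative sketch (not from the original source): with hypothetical tasks
    #
    #   @task
    #   def numbers():
    #       return [1, 2, 3]
    #
    #   @task
    #   def add_one(n):
    #       return n + 1
    #
    #   add_one.expand(n=numbers())
    #
    # iter_mapped_dependants() on the ``numbers`` task would yield the mapped
    # ``add_one`` operator, since its expansion depends on ``numbers``'s output.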

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()


    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """
        Get the mapped task group "closest" to this task in the Dag.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)

    def get_needs_expansion(self) -> bool:
        """
        Return True if the task is a MappedOperator or is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            if self.get_closest_mapped_task_group() is not None:
                self._needs_expansion = True
            else:
                self._needs_expansion = False
        return self._needs_expansion