Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/abstractoperator.py: 47%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

188 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import datetime 

21import logging 

22from abc import abstractmethod 

23from collections.abc import ( 

24 Callable, 

25 Collection, 

26 Iterable, 

27 Iterator, 

28) 

29from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias 

30 

31from airflow.sdk import TriggerRule, WeightRule 

32from airflow.sdk.configuration import conf 

33from airflow.sdk.definitions._internal.mixins import DependencyMixin 

34from airflow.sdk.definitions._internal.node import DAGNode 

35from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext 

36from airflow.sdk.definitions._internal.templater import Templater 

37from airflow.sdk.definitions.context import Context 

38 

39if TYPE_CHECKING: 

40 import jinja2 

41 

42 from airflow.sdk.bases.operator import BaseOperator 

43 from airflow.sdk.bases.operatorlink import BaseOperatorLink 

44 from airflow.sdk.definitions.dag import DAG 

45 from airflow.sdk.definitions.mappedoperator import MappedOperator 

46 from airflow.sdk.definitions.taskgroup import MappedTaskGroup 

47 

# A task state-change callback receives the execution Context and returns None.
TaskStateChangeCallback = Callable[[Context], None]
# Callback attributes on operators accept one callback, a list of callbacks, or None.
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None

# -- Operator defaults. Values taken from ``conf`` are read once, when this module is imported. --
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary-precision integers, so priority weights are limited to the
# narrowest range shared by the supported backends (signed 32-bit):
#   postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
#   mysql:    -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
#   sqlite:   -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -(2**31)
MAXIMUM_PRIORITY_WEIGHT: int = 2**31 - 1
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
# Fallback 86_400 == 24 * 60 * 60, presumably one day expressed in seconds.
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=86_400)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
# (DEFAULT_WEIGHT_RULE already reads ``[core] default_task_weight_rule``).
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
# Annotated as Optional: presumably None when the option is not configured (verify conf.gettimedelta).
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)
DEFAULT_EMAIL_ON_FAILURE: bool = conf.getboolean("email", "default_email_on_failure", fallback=True)
DEFAULT_EMAIL_ON_RETRY: bool = conf.getboolean("email", "default_email_on_retry", fallback=True)

log = logging.getLogger(__name__)

82 

83 

84class AbstractOperator(Templater, DAGNode): 

85 """ 

86 Common implementation for operators, including unmapped and mapped. 

87 

88 This base class is more about sharing implementations, not defining a common 

89 interface. Unfortunately it's difficult to use this as the common base class 

90 for typing due to BaseOperator carrying too much historical baggage. 

91 

92 The union type ``from airflow.models.operator import Operator`` is easier 

93 to use for typing purposes. 

94 

95 :meta private: 

96 """ 

97 

98 operator_class: type[BaseOperator] | dict[str, Any] 

99 

100 priority_weight: int 

101 

102 # Defines the operator level extra links. 

103 operator_extra_links: Collection[BaseOperatorLink] = () 

104 

105 owner: str 

106 task_id: str 

107 

108 outlets: list 

109 inlets: list 

110 

111 trigger_rule: TriggerRule 

112 _needs_expansion: bool | None = None 

113 _on_failure_fail_dagrun = False 

114 is_setup: bool = False 

115 is_teardown: bool = False 

116 

117 HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset( 

118 ( 

119 "log", 

120 "dag", # We show dag_id, don't need to show this too 

121 "node_id", # Duplicates task_id 

122 "task_group", # Doesn't have a useful repr, no point showing in UI 

123 "inherits_from_empty_operator", # impl detail 

124 "inherits_from_skipmixin", # impl detail 

125 # Decide whether to start task execution from triggerer 

126 "start_trigger_args", 

127 "start_from_trigger", 

128 # For compatibility with TG, for operators these are just the current task, no point showing 

129 "roots", 

130 "leaves", 

131 # These lists are already shown via *_task_ids 

132 "upstream_list", 

133 "downstream_list", 

134 # Not useful, implementation detail, already shown elsewhere 

135 "global_operator_extra_link_dict", 

136 "operator_extra_link_dict", 

137 ) 

138 ) 

139 

140 @property 

141 def is_async(self) -> bool: 

142 return False 

143 

144 @property 

145 def task_type(self) -> str: 

146 raise NotImplementedError() 

147 

148 @property 

149 def operator_name(self) -> str: 

150 raise NotImplementedError() 

151 

152 _is_sensor: bool = False 

153 _is_mapped: bool = False 

154 _can_skip_downstream: bool = False 

155 

156 @property 

157 def dag_id(self) -> str: 

158 """Returns dag id if it has one or an adhoc + owner.""" 

159 dag = self.get_dag() 

160 if dag: 

161 return dag.dag_id 

162 return f"adhoc_{self.owner}" 

163 

164 @property 

165 def node_id(self) -> str: 

166 return self.task_id 

167 

168 @property 

169 @abstractmethod 

170 def task_display_name(self) -> str: ... 

171 

172 @property 

173 def is_mapped(self): 

174 return self._is_mapped 

175 

176 @property 

177 def label(self) -> str | None: 

178 if self.task_display_name and self.task_display_name != self.task_id: 

179 return self.task_display_name 

180 # Prefix handling if no display is given is cloned from taskmixin for compatibility 

181 tg = self.task_group 

182 if tg and tg.node_id and tg.prefix_group_id: 

183 # "task_group_id.task_id" -> "task_id" 

184 return self.task_id[len(tg.node_id) + 1 :] 

185 return self.task_id 

186 

187 @property 

188 def on_failure_fail_dagrun(self): 

189 """ 

190 Whether the operator should fail the dagrun on failure. 

191 

192 :meta private: 

193 """ 

194 return self._on_failure_fail_dagrun 

195 

196 @on_failure_fail_dagrun.setter 

197 def on_failure_fail_dagrun(self, value): 

198 """ 

199 Setter for on_failure_fail_dagrun property. 

200 

201 :meta private: 

202 """ 

203 if value is True and self.is_teardown is not True: 

204 raise ValueError( 

205 f"Cannot set task on_failure_fail_dagrun for " 

206 f"'{self.task_id}' because it is not a teardown task." 

207 ) 

208 self._on_failure_fail_dagrun = value 

209 

210 @property 

211 def inherits_from_empty_operator(self): 

212 """Used to determine if an Operator is inherited from EmptyOperator.""" 

213 # This looks like `isinstance(self, EmptyOperator) would work, but this also 

214 # needs to cope when `self` is a Serialized instance of a EmptyOperator or one 

215 # of its subclasses (which don't inherit from anything but BaseOperator). 

216 return getattr(self, "_is_empty", False) 

217 

218 @property 

219 def inherits_from_skipmixin(self): 

220 """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin).""" 

221 return getattr(self, "_can_skip_downstream", False) 

222 

223 def as_setup(self): 

224 self.is_setup = True 

225 return self 

226 

227 def as_teardown( 

228 self, 

229 *, 

230 setups: BaseOperator | Iterable[BaseOperator] | None = None, 

231 on_failure_fail_dagrun: bool | None = None, 

232 ): 

233 self.is_teardown = True 

234 self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS 

235 if on_failure_fail_dagrun is not None: 

236 self.on_failure_fail_dagrun = on_failure_fail_dagrun 

237 if setups is not None: 

238 setups = [setups] if isinstance(setups, DependencyMixin) else setups 

239 for s in setups: 

240 s.is_setup = True 

241 s >> self 

242 return self 

243 

244 def __enter__(self): 

245 if not self.is_setup and not self.is_teardown: 

246 raise RuntimeError("Only setup/teardown tasks can be used as context managers.") 

247 SetupTeardownContext.push_setup_teardown_task(self) 

248 return SetupTeardownContext 

249 

250 def __exit__(self, exc_type, exc_val, exc_tb): 

251 SetupTeardownContext.set_work_task_roots_and_leaves() 

252 

253 # TODO: Task-SDK -- Should the following methods be removed? 

254 # get_template_env 

255 # _render 

256 def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: 

257 """Get the template environment for rendering templates.""" 

258 if dag is None: 

259 dag = self.get_dag() 

260 return super().get_template_env(dag=dag) 

261 

262 def _render(self, template, context, dag: DAG | None = None): 

263 if dag is None: 

264 dag = self.get_dag() 

265 return super()._render(template, context, dag=dag) 

266 

267 def _do_render_template_fields( 

268 self, 

269 parent: Any, 

270 template_fields: Iterable[str], 

271 context: Context, 

272 jinja_env: jinja2.Environment, 

273 seen_oids: set[int], 

274 ) -> None: 

275 """Override the base to use custom error logging.""" 

276 for attr_name in template_fields: 

277 try: 

278 value = getattr(parent, attr_name) 

279 except AttributeError: 

280 raise AttributeError( 

281 f"{attr_name!r} is configured as a template field " 

282 f"but {parent.task_type} does not have this attribute." 

283 ) 

284 try: 

285 if not value: 

286 continue 

287 except Exception: 

288 # This may happen if the templated field points to a class which does not support `__bool__`, 

289 # such as Pandas DataFrames: 

290 # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 

291 log.info( 

292 "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", 

293 type(value).__name__, 

294 self.task_id, 

295 attr_name, 

296 ) 

297 # We may still want to render custom classes which do not support __bool__ 

298 pass 

299 

300 try: 

301 if callable(value): 

302 rendered_content = value(context=context, jinja_env=jinja_env) 

303 else: 

304 rendered_content = self.render_template(value, context, jinja_env, seen_oids) 

305 except Exception: 

306 # Mask sensitive values in the template before logging 

307 from airflow.sdk._shared.secrets_masker import redact 

308 

309 masked_value = redact(value) 

310 log.exception( 

311 "Exception rendering Jinja template for task '%s', field '%s'. Template: %r", 

312 self.task_id, 

313 attr_name, 

314 masked_value, 

315 ) 

316 raise 

317 else: 

318 setattr(parent, attr_name, rendered_content) 

319 

320 def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: 

321 """ 

322 Return mapped nodes that are direct dependencies of the current task. 

323 

324 For now, this walks the entire Dag to find mapped nodes that has this 

325 current task as an upstream. We cannot use ``downstream_list`` since it 

326 only contains operators, not task groups. In the future, we should 

327 provide a way to record a Dag node's all downstream nodes instead. 

328 

329 Note that this does not guarantee the returned tasks actually use the 

330 current task for task mapping, but only checks those task are mapped 

331 operators, and are downstreams of the current task. 

332 

333 To get a list of tasks that uses the current task for task mapping, use 

334 :meth:`iter_mapped_dependants` instead. 

335 """ 

336 from airflow.sdk.definitions.mappedoperator import MappedOperator 

337 from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup 

338 

339 def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: 

340 """ 

341 Recursively walk children in a task group. 

342 

343 This yields all direct children (including both tasks and task 

344 groups), and all children of any task groups. 

345 """ 

346 for key, child in group.children.items(): 

347 yield key, child 

348 if isinstance(child, TaskGroup): 

349 yield from _walk_group(child) 

350 

351 dag = self.get_dag() 

352 if not dag: 

353 raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag") 

354 for key, child in _walk_group(dag.task_group): 

355 if key == self.node_id: 

356 continue 

357 if not isinstance(child, (MappedOperator, MappedTaskGroup)): 

358 continue 

359 if self.node_id in child.upstream_task_ids: 

360 yield child 

361 

362 def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]: 

363 """ 

364 Return mapped nodes that depend on the current task the expansion. 

365 

366 For now, this walks the entire Dag to find mapped nodes that has this 

367 current task as an upstream. We cannot use ``downstream_list`` since it 

368 only contains operators, not task groups. In the future, we should 

369 provide a way to record a Dag node's all downstream nodes instead. 

370 """ 

371 return ( 

372 downstream 

373 for downstream in self._iter_all_mapped_downstreams() 

374 if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies()) 

375 ) 

376 

377 def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: 

378 """ 

379 Return mapped task groups this task belongs to. 

380 

381 Groups are returned from the innermost to the outmost. 

382 

383 :meta private: 

384 """ 

385 if (group := self.task_group) is None: 

386 return 

387 yield from group.iter_mapped_task_groups() 

388 

389 def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: 

390 """ 

391 Get the mapped task group "closest" to this task in the Dag. 

392 

393 :meta private: 

394 """ 

395 return next(self.iter_mapped_task_groups(), None) 

396 

397 def get_needs_expansion(self) -> bool: 

398 """ 

399 Return true if the task is MappedOperator or is in a mapped task group. 

400 

401 :meta private: 

402 """ 

403 if self._needs_expansion is None: 

404 if self.get_closest_mapped_task_group() is not None: 

405 self._needs_expansion = True 

406 else: 

407 self._needs_expansion = False 

408 return self._needs_expansion