Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/abstractoperator.py: 45%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

197 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import datetime 

21import logging 

22from abc import abstractmethod 

23from collections.abc import ( 

24 Callable, 

25 Collection, 

26 Iterable, 

27 Iterator, 

28) 

29from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias 

30 

31from airflow.sdk import TriggerRule, WeightRule 

32from airflow.sdk.configuration import conf 

33from airflow.sdk.definitions._internal.mixins import DependencyMixin 

34from airflow.sdk.definitions._internal.node import DAGNode 

35from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext 

36from airflow.sdk.definitions._internal.templater import Templater 

37from airflow.sdk.definitions.context import Context 

38 

39if TYPE_CHECKING: 

40 import jinja2 

41 

42 from airflow.sdk.bases.operator import BaseOperator 

43 from airflow.sdk.bases.operatorlink import BaseOperatorLink 

44 from airflow.sdk.definitions.dag import DAG 

45 from airflow.sdk.definitions.mappedoperator import MappedOperator 

46 from airflow.sdk.definitions.taskgroup import MappedTaskGroup 

47 

# Signature of a task state-change callback: receives the execution Context, returns nothing.
TaskStateChangeCallback = Callable[[Context], None]
# Operators accept a single callback, a list of callbacks, or None for these hooks.
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None

# Operator-level defaults, sourced from the Airflow config where a setting exists.
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
# Upper bound on retry delay, in seconds (defaults to 24 hours).
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)
DEFAULT_EMAIL_ON_FAILURE: bool = conf.getboolean("email", "default_email_on_failure", fallback=True)
DEFAULT_EMAIL_ON_RETRY: bool = conf.getboolean("email", "default_email_on_retry", fallback=True)
log = logging.getLogger(__name__)

82 

83 

class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    # The concrete operator class, or a dict of operator data.
    # NOTE(review): the dict form presumably comes from (de)serialization -- confirm.
    operator_class: type[BaseOperator] | dict[str, Any]

    # Scheduling priority weight (see the priority-weight bounds defined above).
    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    owner: str
    task_id: str

    # Lineage/asset outputs and inputs attached to this task.
    outlets: list
    inlets: list

    # Rule deciding when this task may run relative to its upstream tasks.
    trigger_rule: TriggerRule
    # Tri-state cache used by get_needs_expansion(); None means "not yet computed".
    _needs_expansion: bool | None = None
    # Backing field for the on_failure_fail_dagrun property defined below.
    _on_failure_fail_dagrun: bool = False
    # Setup/teardown markers; toggled by as_setup()/as_teardown().
    is_setup: bool = False
    is_teardown: bool = False

    # Attribute names suppressed when rendering this operator's details in the UI.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

139 

    @property
    def is_async(self) -> bool:
        """Whether this operator executes asynchronously; ``False`` here, subclasses may override."""
        return False

    @property
    def task_type(self) -> str:
        """Operator type name; must be provided by concrete subclasses."""
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        """Operator display name; must be provided by concrete subclasses."""
        raise NotImplementedError()

    # Class-level flags; these defaults are presumably overridden by concrete
    # subclasses (sensors, mapped operators, skip-capable operators) -- confirm.
    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False

155 

156 @property 

157 def dag_id(self) -> str: 

158 """Returns dag id if it has one or an adhoc + owner.""" 

159 dag = self.get_dag() 

160 if dag: 

161 return dag.dag_id 

162 return f"adhoc_{self.owner}" 

163 

    @property
    def node_id(self) -> str:
        """Graph-node identifier; for an operator this is simply the task_id."""
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def is_mapped(self) -> bool:
        """Whether this operator is mapped; backed by the ``_is_mapped`` class flag."""
        return self._is_mapped

175 

176 @property 

177 def label(self) -> str | None: 

178 if self.task_display_name and self.task_display_name != self.task_id: 

179 return self.task_display_name 

180 # Prefix handling if no display is given is cloned from taskmixin for compatibility 

181 tg = self.task_group 

182 if tg and tg.node_id and tg.prefix_group_id: 

183 # "task_group_id.task_id" -> "task_id" 

184 return self.task_id[len(tg.node_id) + 1 :] 

185 return self.task_id 

186 

187 @property 

188 def on_failure_fail_dagrun(self): 

189 """ 

190 Whether the operator should fail the dagrun on failure. 

191 

192 :meta private: 

193 """ 

194 return self._on_failure_fail_dagrun 

195 

196 @on_failure_fail_dagrun.setter 

197 def on_failure_fail_dagrun(self, value): 

198 """ 

199 Setter for on_failure_fail_dagrun property. 

200 

201 :meta private: 

202 """ 

203 if value is True and self.is_teardown is not True: 

204 raise ValueError( 

205 f"Cannot set task on_failure_fail_dagrun for " 

206 f"'{self.task_id}' because it is not a teardown task." 

207 ) 

208 self._on_failure_fail_dagrun = value 

209 

    @property
    def inherits_from_empty_operator(self):
        """Used to determine if an Operator is inherited from EmptyOperator."""
        # This looks like `isinstance(self, EmptyOperator) would work, but this also
        # needs to cope when `self` is a Serialized instance of a EmptyOperator or one
        # of its subclasses (which don't inherit from anything but BaseOperator).
        return getattr(self, "_is_empty", False)

    @property
    def inherits_from_skipmixin(self):
        """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin)."""
        # Same duck-typed check as above: rely on a class flag rather than isinstance.
        return getattr(self, "_can_skip_downstream", False)

    def as_setup(self):
        """Mark this task as a setup task and return it, allowing call chaining."""
        self.is_setup = True
        return self

226 

227 def as_teardown( 

228 self, 

229 *, 

230 setups: BaseOperator | Iterable[BaseOperator] | None = None, 

231 on_failure_fail_dagrun: bool | None = None, 

232 ): 

233 self.is_teardown = True 

234 self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS 

235 if on_failure_fail_dagrun is not None: 

236 self.on_failure_fail_dagrun = on_failure_fail_dagrun 

237 if setups is not None: 

238 setups = [setups] if isinstance(setups, DependencyMixin) else setups 

239 for s in setups: 

240 s.is_setup = True 

241 s >> self 

242 return self 

243 

244 def __enter__(self): 

245 if not self.is_setup and not self.is_teardown: 

246 raise RuntimeError("Only setup/teardown tasks can be used as context managers.") 

247 SetupTeardownContext.push_setup_teardown_task(self) 

248 return SetupTeardownContext 

249 

250 def __exit__(self, exc_type, exc_val, exc_tb): 

251 SetupTeardownContext.set_work_task_roots_and_leaves() 

252 

253 # TODO: Task-SDK -- Should the following methods removed? 

254 # get_template_env 

255 # _render 

256 def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: 

257 """Get the template environment for rendering templates.""" 

258 from airflow.sdk.definitions._internal.templater import create_template_env 

259 

260 if dag is None: 

261 dag = self.get_dag() 

262 # Check if the operator has an explicit native rendering preference 

263 render_op_template_as_native_obj = getattr(self, "render_template_as_native_obj", None) 

264 if render_op_template_as_native_obj is not None: 

265 if dag: 

266 # Use dag's template settings (searchpath, macros, filters, etc.) 

267 searchpath = [dag.folder] 

268 if dag.template_searchpath: 

269 searchpath += dag.template_searchpath 

270 return create_template_env( 

271 native=render_op_template_as_native_obj, 

272 searchpath=searchpath, 

273 template_undefined=dag.template_undefined, 

274 jinja_environment_kwargs=dag.jinja_environment_kwargs, 

275 user_defined_macros=dag.user_defined_macros, 

276 user_defined_filters=dag.user_defined_filters, 

277 ) 

278 # No dag context available, use minimal template env 

279 return create_template_env(native=render_op_template_as_native_obj) 

280 # No operator-level override, delegate to parent class 

281 return super().get_template_env(dag=dag) 

282 

283 def _render(self, template, context, dag: DAG | None = None): 

284 if dag is None: 

285 dag = self.get_dag() 

286 return super()._render(template, context, dag=dag) 

287 

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """
        Override the base to use custom error logging.

        For each named field on *parent*, render its value (invoking it if it is
        callable, otherwise passing it through ``render_template``) and assign the
        result back onto *parent* in place.

        :param parent: Object whose template fields are rendered in place.
        :param template_fields: Names of the attributes on *parent* to render.
        :param context: Execution context passed to callables and templates.
        :param jinja_env: Jinja environment used for rendering.
        :param seen_oids: Object ids already rendered (cycle/duplicate guard).
        :raises AttributeError: If a configured field does not exist on *parent*.
        """
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                # Falsy values (empty string/list/None/...) are skipped entirely.
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                if callable(value):
                    # Callable fields are invoked with the context instead of templated.
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                # Only assign on success; a failed render leaves the field untouched.
                setattr(parent, attr_name, rendered_content)

340 

341 def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: 

342 """ 

343 Return mapped nodes that are direct dependencies of the current task. 

344 

345 For now, this walks the entire Dag to find mapped nodes that has this 

346 current task as an upstream. We cannot use ``downstream_list`` since it 

347 only contains operators, not task groups. In the future, we should 

348 provide a way to record a Dag node's all downstream nodes instead. 

349 

350 Note that this does not guarantee the returned tasks actually use the 

351 current task for task mapping, but only checks those task are mapped 

352 operators, and are downstreams of the current task. 

353 

354 To get a list of tasks that uses the current task for task mapping, use 

355 :meth:`iter_mapped_dependants` instead. 

356 """ 

357 from airflow.sdk.definitions.mappedoperator import MappedOperator 

358 from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup 

359 

360 def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]: 

361 """ 

362 Recursively walk children in a task group. 

363 

364 This yields all direct children (including both tasks and task 

365 groups), and all children of any task groups. 

366 """ 

367 for key, child in group.children.items(): 

368 yield key, child 

369 if isinstance(child, TaskGroup): 

370 yield from _walk_group(child) 

371 

372 dag = self.get_dag() 

373 if not dag: 

374 raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag") 

375 for key, child in _walk_group(dag.task_group): 

376 if key == self.node_id: 

377 continue 

378 if not isinstance(child, (MappedOperator, MappedTaskGroup)): 

379 continue 

380 if self.node_id in child.upstream_task_ids: 

381 yield child 

382 

383 def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]: 

384 """ 

385 Return mapped nodes that depend on the current task the expansion. 

386 

387 For now, this walks the entire Dag to find mapped nodes that has this 

388 current task as an upstream. We cannot use ``downstream_list`` since it 

389 only contains operators, not task groups. In the future, we should 

390 provide a way to record a Dag node's all downstream nodes instead. 

391 """ 

392 return ( 

393 downstream 

394 for downstream in self._iter_all_mapped_downstreams() 

395 if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies()) 

396 ) 

397 

398 def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: 

399 """ 

400 Return mapped task groups this task belongs to. 

401 

402 Groups are returned from the innermost to the outmost. 

403 

404 :meta private: 

405 """ 

406 if (group := self.task_group) is None: 

407 return 

408 yield from group.iter_mapped_task_groups() 

409 

410 def get_closest_mapped_task_group(self) -> MappedTaskGroup | None: 

411 """ 

412 Get the mapped task group "closest" to this task in the Dag. 

413 

414 :meta private: 

415 """ 

416 return next(self.iter_mapped_task_groups(), None) 

417 

418 def get_needs_expansion(self) -> bool: 

419 """ 

420 Return true if the task is MappedOperator or is in a mapped task group. 

421 

422 :meta private: 

423 """ 

424 if self._needs_expansion is None: 

425 if self.get_closest_mapped_task_group() is not None: 

426 self._needs_expansion = True 

427 else: 

428 self._needs_expansion = False 

429 return self._needs_expansion