Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/abstractoperator.py: 47%

185 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import logging
from abc import abstractmethod
from collections.abc import (
    Callable,
    Collection,
    Iterable,
    Iterator,
)
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias

from airflow.sdk import TriggerRule, WeightRule
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.mixins import DependencyMixin
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
from airflow.sdk.definitions._internal.templater import Templater
from airflow.sdk.definitions.context import Context

if TYPE_CHECKING:
    import jinja2

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.bases.operatorlink import BaseOperatorLink
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.definitions.mappedoperator import MappedOperator
    from airflow.sdk.definitions.taskgroup import MappedTaskGroup

TaskStateChangeCallback = Callable[[Context], None]
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None
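# A minimal sketch (not part of the original module) of what these aliases
# admit: any callable taking the task ``Context``, e.g. for hooks such as
# ``on_failure_callback``. The callback body below is hypothetical.
#
#     def notify_on_failure(context: Context) -> None:
#         print(f"Task {context['ti'].task_id} failed")
#
#     # Anywhere a TaskStateChangeCallbackAttrType is accepted, a single
#     # callable, a list of callables, or None is valid:
#     callbacks: TaskStateChangeCallbackAttrType = [notify_on_failure]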

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
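# A minimal sketch (not part of the original module) of how these bounds
# could be applied before persisting a weight; ``clamp_priority_weight`` is
# a hypothetical helper, not an Airflow API.
#
#     def clamp_priority_weight(weight: int) -> int:
#         # Keep the value within the signed 32-bit range shared by
#         # postgres and mysql integer columns.
#         return max(MINIMUM_PRIORITY_WEIGHT, min(MAXIMUM_PRIORITY_WEIGHT, weight))
#
#     assert clamp_priority_weight(2**40) == MAXIMUM_PRIORITY_WEIGHT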

DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)
DEFAULT_EMAIL_ON_FAILURE: bool = conf.getboolean("email", "default_email_on_failure", fallback=True)
DEFAULT_EMAIL_ON_RETRY: bool = conf.getboolean("email", "default_email_on_retry", fallback=True)
log = logging.getLogger(__name__)


class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations than defining a
    common interface. Unfortunately it's difficult to use this as the common
    base class for typing due to BaseOperator carrying too much historical
    baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    owner: str
    task_id: str

    outlets: list
    inlets: list

    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False
    is_setup: bool = False
    is_teardown: bool = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False

    @property
    def dag_id(self) -> str:
        """Return the Dag id if the task has one, or an ad hoc id derived from the owner."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def is_mapped(self):
        return self._is_mapped

    @property
    def label(self) -> str | None:
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # Prefix handling when no display name is given is cloned from taskmixin for compatibility
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id
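
    # Illustrative sketch (not part of the original module); the task and
    # group ids below are hypothetical, and task creation assumes an active
    # DAG context.
    #
    #     with TaskGroup("extract") as tg:        # tg.prefix_group_id is True
    #         t = EmptyOperator(task_id="load")   # full task_id: "extract.load"
    #     assert t.label == "load"                # the group prefix is stripped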

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value

    @property
    def inherits_from_empty_operator(self):
        """Used to determine if an Operator is inherited from EmptyOperator."""
        # This looks like `isinstance(self, EmptyOperator)` would work, but this also
        # needs to cope when `self` is a Serialized instance of an EmptyOperator or one
        # of its subclasses (which don't inherit from anything but BaseOperator).
        return getattr(self, "_is_empty", False)

    @property
    def inherits_from_skipmixin(self):
        """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin)."""
        return getattr(self, "_can_skip_downstream", False)

218 

219 def as_setup(self): 

220 self.is_setup = True 

221 return self 

222 

223 def as_teardown( 

224 self, 

225 *, 

226 setups: BaseOperator | Iterable[BaseOperator] | None = None, 

227 on_failure_fail_dagrun: bool | None = None, 

228 ): 

229 self.is_teardown = True 

230 self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS 

231 if on_failure_fail_dagrun is not None: 

232 self.on_failure_fail_dagrun = on_failure_fail_dagrun 

233 if setups is not None: 

234 setups = [setups] if isinstance(setups, DependencyMixin) else setups 

235 for s in setups: 

236 s.is_setup = True 

237 s >> self 

238 return self 
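
    # A minimal usage sketch (not part of the original module), assuming
    # EmptyOperator; the task names are hypothetical.
    #
    #     create = EmptyOperator(task_id="create_resource")
    #     work = EmptyOperator(task_id="work")
    #     delete = EmptyOperator(task_id="delete_resource")
    #
    #     # Marks `delete` as a teardown, marks `create` as its setup, and
    #     # wires create >> delete; the ALL_DONE_SETUP_SUCCESS trigger rule
    #     # lets the teardown run even when `work` fails.
    #     create >> work >> delete.as_teardown(setups=create)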

    def __enter__(self):
        if not self.is_setup and not self.is_teardown:
            raise RuntimeError("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()
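
    # A minimal sketch (not part of the original module) of the context
    # manager behavior: only a task already marked setup/teardown may be
    # used with ``with``, and tasks created inside the block become the
    # "work" between setup and teardown. Task names are hypothetical.
    #
    #     with delete.as_teardown(setups=create):
    #         work = EmptyOperator(task_id="work")  # create >> work >> delete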

    # TODO: Task-SDK -- Should the following methods be removed?
    # get_template_env
    # _render
    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                if callable(value):
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)
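
    # Illustrative sketch (not part of the original module) of what this
    # rendering loop supports: plain Jinja strings in a template field are
    # rendered via ``render_template``, while a callable field is invoked
    # with the context and Jinja environment. ``MyOperator`` is hypothetical.
    #
    #     class MyOperator(BaseOperator):
    #         template_fields = ("bucket",)
    #
    #         def __init__(self, bucket, **kwargs):
    #             super().__init__(**kwargs)
    #             self.bucket = bucket
    #
    #     # "{{ ds }}" is rendered at runtime; a callable value would be
    #     # called as bucket(context=context, jinja_env=jinja_env) instead.
    #     MyOperator(task_id="t", bucket="data-{{ ds }}")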

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping; it only checks that those tasks are
        mapped operators and are downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.sdk.definitions.mappedoperator import MappedOperator
        from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """
            Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )
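
    # Illustrative sketch (not part of the original module), assuming the
    # ``@task`` decorator; the task names are hypothetical. ``make_list``
    # feeds the expansion of ``double``, so ``double`` is a mapped
    # dependant of it:
    #
    #     @task
    #     def make_list():
    #         return [1, 2, 3]
    #
    #     @task
    #     def double(x):
    #         return 2 * x
    #
    #     vals = make_list()
    #     double.expand(x=vals)  # iter_mapped_dependants() on make_list yields this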

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """
        Get the mapped task group "closest" to this task in the Dag.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)

    def get_needs_expansion(self) -> bool:
        """
        Return True if the task is a MappedOperator, or is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            if self.get_closest_mapped_task_group() is not None:
                self._needs_expansion = True
            else:
                self._needs_expansion = False
        return self._needs_expansion
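
    # Illustrative sketch (not part of the original module): a classically
    # mapped task needs expansion, while a plain task only does when it
    # lives inside a mapped task group. ``MyOperator`` is hypothetical.
    #
    #     mapped = MyOperator.partial(task_id="m").expand(bucket=["a", "b"])
    #     assert mapped.get_needs_expansion()      # it is a MappedOperator
    #
    #     plain = MyOperator(task_id="p", bucket="x")
    #     assert not plain.get_needs_expansion()   # not in any mapped group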