Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/abstractoperator.py: 47% (192 statements)

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import logging
from abc import abstractmethod
from collections.abc import (
    Callable,
    Collection,
    Iterable,
    Iterator,
)
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias

import methodtools

from airflow.sdk import TriggerRule, WeightRule
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.mixins import DependencyMixin
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
from airflow.sdk.definitions._internal.templater import Templater
from airflow.sdk.definitions.context import Context

if TYPE_CHECKING:
    import jinja2

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.bases.operatorlink import BaseOperatorLink
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.definitions.mappedoperator import MappedOperator
    from airflow.sdk.definitions.taskgroup import MappedTaskGroup

TaskStateChangeCallback = Callable[[Context], None]
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None
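
# Illustrative sketch (not code from this module): a callable matching
# ``TaskStateChangeCallback`` accepts the task Context and returns None, e.g.
#
#     def notify_on_success(context: Context) -> None:
#         print(f"task {context['ti'].task_id} finished")
#
# Either one such callable or a list of them satisfies
# ``TaskStateChangeCallbackAttrType``.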

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
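
# Illustrative sketch (an assumption, not code from this module): a computed
# priority weight, e.g. one summed over many downstream tasks, can exceed the
# 32-bit range above, so keeping it storable amounts to:
#
#     def clamp_priority_weight(weight: int) -> int:
#         return max(MINIMUM_PRIORITY_WEIGHT, min(weight, MAXIMUM_PRIORITY_WEIGHT))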

DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
DEFAULT_RETRY_DELAY_MULTIPLIER: float = 2.0
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)

log = logging.getLogger(__name__)


class NotMapped(Exception):
    """Raise if a task is neither mapped nor has any parent mapped groups."""


class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    owner: str
    task_id: str

    outlets: list
    inlets: list

    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False
    is_setup: bool = False
    is_teardown: bool = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False

    @property
    def dag_id(self) -> str:
        """Return the Dag id if the task is attached to a Dag, or an ad-hoc id derived from the owner."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def is_mapped(self):
        return self._is_mapped

    @property
    def label(self) -> str | None:
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # If no display name is given, the prefix handling is cloned from taskmixin for compatibility.
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id
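
    # Example grounded in the slice above: for a task "t1" inside a TaskGroup
    # "tg" with prefix_group_id enabled, task_id is "tg.t1" and the group's
    # node_id is "tg", so task_id[len("tg") + 1:] evaluates to the label "t1".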

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value

    @property
    def inherits_from_empty_operator(self):
        """Used to determine if an Operator is inherited from EmptyOperator."""
        # This looks like ``isinstance(self, EmptyOperator)`` would work, but this also
        # needs to cope when ``self`` is a serialized instance of an EmptyOperator or one
        # of its subclasses (which don't inherit from anything but BaseOperator).
        return getattr(self, "_is_empty", False)

    @property
    def inherits_from_skipmixin(self):
        """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin)."""
        return getattr(self, "_can_skip_downstream", False)

    def as_setup(self):
        self.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | None = None,
        on_failure_fail_dagrun: bool | None = None,
    ):
        self.is_teardown = True
        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
        if on_failure_fail_dagrun is not None:
            self.on_failure_fail_dagrun = on_failure_fail_dagrun
        if setups is not None:
            setups = [setups] if isinstance(setups, DependencyMixin) else setups
            for s in setups:
                s.is_setup = True
                s >> self
        return self
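
    # Usage sketch (Dag-authoring code, for illustration; the task names are
    # hypothetical):
    #
    #     create_resource >> work >> delete_resource.as_teardown(
    #         setups=create_resource, on_failure_fail_dagrun=True
    #     )
    #
    # This marks delete_resource as a teardown with trigger rule
    # ALL_DONE_SETUP_SUCCESS, marks create_resource as a setup, and wires
    # create_resource upstream of delete_resource via ``s >> self``.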

    def __enter__(self):
        if not self.is_setup and not self.is_teardown:
            raise RuntimeError("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()
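
    # Usage sketch (illustrative; task names are hypothetical): since
    # __enter__ requires is_setup or is_teardown, only tasks marked via
    # as_setup()/as_teardown() can open a scope:
    #
    #     with delete_cluster.as_teardown(setups=create_cluster):
    #         run_job  # work tasks defined/wired inside the block
    #
    # On exit, SetupTeardownContext.set_work_task_roots_and_leaves() records
    # the enclosed work's roots and leaves.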

    # TODO: Task-SDK -- Should the following methods be removed?
    # get_template_env
    # _render
    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                if callable(value):
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)
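
    # Why the truthiness check above needs try/except (minimal reproduction,
    # not code from this module): some objects raise in __bool__, notably
    # pandas DataFrames:
    #
    #     class AmbiguousTruth:
    #         def __bool__(self) -> bool:
    #             raise ValueError("truth value is ambiguous")
    #
    #     if not AmbiguousTruth():  # raises ValueError, like bool(DataFrame)
    #         ...
    #
    # When that happens the field is still rendered rather than skipped.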

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping; it only checks that they are mapped
        operators and are downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.sdk.definitions.mappedoperator import MappedOperator
        from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """
            Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )
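
    # Example (illustrative; names are hypothetical): given
    #
    #     mapped = SomeOperator.partial(task_id="m").expand(x=upstream.output)
    #
    # both _iter_all_mapped_downstreams() and iter_mapped_dependants() yield
    # ``mapped`` for ``upstream``, because the expansion input references
    # upstream's output. If the task instead used a literal
    # ``.expand(x=[1, 2, 3])``, only _iter_all_mapped_downstreams() would
    # yield it.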

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """
        Get the mapped task group "closest" to this task in the Dag.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)
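
    # Example (illustrative): for a task nested inside
    # outer_mapped_group -> inner_mapped_group, iter_mapped_task_groups()
    # yields inner_mapped_group first, so get_closest_mapped_task_group()
    # returns inner_mapped_group.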

    def get_needs_expansion(self) -> bool:
        """
        Return True if the task is a MappedOperator, or is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            if self.get_closest_mapped_task_group() is not None:
                self._needs_expansion = True
            else:
                self._needs_expansion = False
        return self._needs_expansion

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of mapped task instances that can be created at Dag run creation.

        This only considers literal mapped arguments; the count cannot be
        determined at parse time when any non-literal values are used for
        mapping.

        :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
        :raise NotMapped: If the operator is neither mapped, nor has any parent
            mapped task groups.
        :return: Total number of mapped TIs this task should have.
        """
        group = self.get_closest_mapped_task_group()
        if group is None:
            raise NotMapped()
        return group.get_parse_time_mapped_ti_count()
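
    # Illustrative behaviour (an assumption based on the docstring above; the
    # operator name is hypothetical): a task mapped over literals, e.g.
    #
    #     SomeOperator.partial(task_id="t").expand(x=[1, 2, 3])
    #
    # has a parse-time mapped TI count of 3. Mapping over an XComArg makes the
    # count unknowable at parse time (NotFullyPopulated), and calling this on
    # a task that is neither mapped nor inside a mapped task group raises
    # NotMapped.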