#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import logging
from abc import abstractmethod
from collections.abc import (
    Callable,
    Collection,
    Iterable,
    Iterator,
)
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias

from airflow.sdk import TriggerRule, WeightRule
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.mixins import DependencyMixin
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
from airflow.sdk.definitions._internal.templater import Templater
from airflow.sdk.definitions.context import Context

if TYPE_CHECKING:
    import jinja2

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.bases.operatorlink import BaseOperatorLink
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.definitions.mappedoperator import MappedOperator
    from airflow.sdk.definitions.taskgroup import MappedTaskGroup

TaskStateChangeCallback = Callable[[Context], None]
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None
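# Callback signature sketch (illustrative, not part of the public surface of this
# module): a state-change callback receives only the task context, so a simple
# hook could look like:
#
#     def notify(context: Context) -> None:
#         log.info("Task %s changed state", context["ti"].task_id)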

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)

log = logging.getLogger(__name__)


class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    priority_weight: int

    # Defines the operator-level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    owner: str
    task_id: str

    outlets: list
    inlets: list

    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False
    is_setup: bool = False
    is_teardown: bool = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

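    # Capability flags; concrete subclasses override these (e.g. sensors set
    # _is_sensor, MappedOperator sets _is_mapped, and skip-capable operators
    # set _can_skip_downstream).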
    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False

    @property
    def dag_id(self) -> str:
        """Return the DAG ID if the task is attached to a DAG, otherwise an ad hoc ID derived from the owner."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def is_mapped(self):
        return self._is_mapped

    @property
    def label(self) -> str | None:
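        """
        Pretty label for this task.

        Returns the display name if one is set, otherwise the task ID stripped
        of its task group prefix.
        """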
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # Prefix handling if no display name is given is cloned from taskmixin for compatibility
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for the on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value

    @property
    def inherits_from_empty_operator(self):
        """Used to determine if an Operator is inherited from EmptyOperator."""
        # This looks like `isinstance(self, EmptyOperator)` would work, but this also
        # needs to cope when `self` is a Serialized instance of an EmptyOperator or one
        # of its subclasses (which don't inherit from anything but BaseOperator).
        return getattr(self, "_is_empty", False)

    @property
    def inherits_from_skipmixin(self):
        """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin)."""
        return getattr(self, "_can_skip_downstream", False)

    def as_setup(self):
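        """Mark this task as a setup task and return it, allowing call chaining."""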
        self.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | None = None,
        on_failure_fail_dagrun: bool | None = None,
    ):
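        """
        Mark this task as a teardown and return it.

        Sets ``trigger_rule`` to ``ALL_DONE_SETUP_SUCCESS``; when ``setups`` is
        given, those tasks are marked as setups and wired upstream of this one.

        A minimal usage sketch (task names are illustrative)::

            create_resource >> run_query >> delete_resource.as_teardown(setups=create_resource)
        """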
        self.is_teardown = True
        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
        if on_failure_fail_dagrun is not None:
            self.on_failure_fail_dagrun = on_failure_fail_dagrun
        if setups is not None:
            setups = [setups] if isinstance(setups, DependencyMixin) else setups
            for s in setups:
                s.is_setup = True
                s >> self
        return self

    def __enter__(self):
        if not self.is_setup and not self.is_teardown:
            raise RuntimeError("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()
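    # Context-manager usage sketch (operator and task names are illustrative):
    # tasks created inside the ``with`` block are registered by
    # SetupTeardownContext as the "work" between the setup and teardown::
    #
    #     with delete_resource.as_teardown(setups=create_resource):
    #         RunQueryOperator(task_id="run_query")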

    # TODO: Task-SDK -- Should the following methods be removed?
    # get_template_env
    # _render
    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                if callable(value):
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping; it only checks that those tasks are
        mapped operators and are downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.sdk.definitions.mappedoperator import MappedOperator
        from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """
            Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """
        Get the mapped task group "closest" to this task in the Dag.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)

    def get_needs_expansion(self) -> bool:
        """
        Return true if the task is a MappedOperator, or if it is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            self._needs_expansion = self.get_closest_mapped_task_group() is not None
        return self._needs_expansion