#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import logging
from abc import abstractmethod
from collections.abc import (
    Callable,
    Collection,
    Iterable,
    Iterator,
)
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias

from airflow.sdk import TriggerRule, WeightRule
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.mixins import DependencyMixin
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
from airflow.sdk.definitions._internal.templater import Templater
from airflow.sdk.definitions.context import Context

if TYPE_CHECKING:
    import jinja2

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.bases.operatorlink import BaseOperatorLink
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.definitions.mappedoperator import MappedOperator
    from airflow.sdk.definitions.taskgroup import MappedTaskGroup

TaskStateChangeCallback = Callable[[Context], None]
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None
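
# State-change callbacks receive the task's runtime context. A minimal sketch
# (``notify`` is a hypothetical helper, not part of the SDK):
#
#     def on_failure(context: Context) -> None:
#         notify(f"task {context['ti'].task_id} failed")
#
# Attributes typed as TaskStateChangeCallbackAttrType accept a single such
# callback, a list of them, or None.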

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary-precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)
DEFAULT_EMAIL_ON_FAILURE: bool = conf.getboolean("email", "default_email_on_failure", fallback=True)
DEFAULT_EMAIL_ON_RETRY: bool = conf.getboolean("email", "default_email_on_retry", fallback=True)
log = logging.getLogger(__name__)


class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations than defining a
    common interface. Unfortunately it's difficult to use this as the common
    base class for typing due to BaseOperator carrying too much historical
    baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    owner: str
    task_id: str

    outlets: list
    inlets: list

    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False
    is_setup: bool = False
    is_teardown: bool = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    @property
    def is_async(self) -> bool:
        return False

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False

    @property
    def dag_id(self) -> str:
        """Return the dag id if the task has one, otherwise an ad-hoc id derived from the owner."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def is_mapped(self):
        return self._is_mapped

    @property
    def label(self) -> str | None:
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # If no display name is given, prefix handling is cloned from taskmixin for compatibility
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id
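
    # For example, a task with task_id "tg1.extract" inside task group "tg1"
    # (with prefix_group_id=True) gets the label "extract".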

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for the on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value

    @property
    def inherits_from_empty_operator(self):
        """Used to determine if an Operator inherits from EmptyOperator."""
        # This looks like `isinstance(self, EmptyOperator)` would work, but this also
        # needs to cope when `self` is a Serialized instance of an EmptyOperator or one
        # of its subclasses (which don't inherit from anything but BaseOperator).
        return getattr(self, "_is_empty", False)

    @property
    def inherits_from_skipmixin(self):
        """Used to determine if an Operator inherits from SkipMixin or one of its subclasses (e.g., BranchMixin)."""
        return getattr(self, "_can_skip_downstream", False)

    def as_setup(self):
        self.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | None = None,
        on_failure_fail_dagrun: bool | None = None,
    ):
        self.is_teardown = True
        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
        if on_failure_fail_dagrun is not None:
            self.on_failure_fail_dagrun = on_failure_fail_dagrun
        if setups is not None:
            setups = [setups] if isinstance(setups, DependencyMixin) else setups
            for s in setups:
                s.is_setup = True
                s >> self
        return self
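
    # A minimal sketch of setup/teardown wiring (task names are hypothetical):
    #
    #     create_cluster >> run_job >> delete_cluster.as_teardown(
    #         setups=create_cluster, on_failure_fail_dagrun=True
    #     )
    #
    # as_teardown() marks delete_cluster as a teardown, switches its trigger
    # rule to ALL_DONE_SETUP_SUCCESS, and registers create_cluster as its
    # setup counterpart.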

    def __enter__(self):
        if not self.is_setup and not self.is_teardown:
            raise RuntimeError("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()
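
    # The context-manager form scopes work tasks between a setup/teardown
    # pair. A minimal sketch (task names are hypothetical):
    #
    #     with delete_cluster.as_teardown(setups=create_cluster):
    #         run_query1 >> run_query2
    #
    # Tasks arranged inside the ``with`` block become the work tasks between
    # create_cluster and delete_cluster.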

    # TODO: Task-SDK -- Should the following methods be removed?
    # get_template_env
    # _render
    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        from airflow.sdk.definitions._internal.templater import create_template_env

        if dag is None:
            dag = self.get_dag()
        # Check if the operator has an explicit native rendering preference
        render_op_template_as_native_obj = getattr(self, "render_template_as_native_obj", None)
        if render_op_template_as_native_obj is not None:
            if dag:
                # Use the dag's template settings (searchpath, macros, filters, etc.)
                searchpath = [dag.folder]
                if dag.template_searchpath:
                    searchpath += dag.template_searchpath
                return create_template_env(
                    native=render_op_template_as_native_obj,
                    searchpath=searchpath,
                    template_undefined=dag.template_undefined,
                    jinja_environment_kwargs=dag.jinja_environment_kwargs,
                    user_defined_macros=dag.user_defined_macros,
                    user_defined_filters=dag.user_defined_filters,
                )
            # No dag context available, use a minimal template env
            return create_template_env(native=render_op_template_as_native_obj)
        # No operator-level override, delegate to the parent class
        return super().get_template_env(dag=dag)
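
    # Native rendering changes what templates evaluate to. A minimal sketch,
    # assuming the operator was created with render_template_as_native_obj=True:
    #
    #     env = task.get_template_env()
    #     env.from_string("{{ [1, 2] }}").render()  # -> the list [1, 2]
    #
    # With a non-native environment the same template renders to the string
    # "[1, 2]".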

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                if callable(value):
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)
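
    # Template fields may hold Jinja strings or callables. A minimal sketch
    # (operator and field names are hypothetical):
    #
    #     class MyOperator(BaseOperator):
    #         template_fields = ("sql",)
    #
    #     MyOperator(task_id="t", sql="SELECT 1 WHERE ds = '{{ ds }}'")
    #
    # A string field is passed through render_template(); a callable field is
    # invoked directly as value(context=..., jinja_env=...).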

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping, but only checks that those tasks are
        mapped operators and are downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.sdk.definitions.mappedoperator import MappedOperator
        from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """
            Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )
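
    # A minimal sketch of the distinction, assuming hypothetical @task-style
    # tasks:
    #
    #     nums = make_list()               # upstream task
    #     doubled = double.expand(x=nums)  # mapped over nums' output
    #     total = summarize(doubled)       # plain (unmapped) downstream
    #
    # iter_mapped_dependants() on the make_list task yields the mapped
    # ``double`` task because it expands over this task's output, while
    # ``summarize`` is downstream but never yielded.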

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """
        Get the mapped task group "closest" to this task in the Dag.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)

    def get_needs_expansion(self) -> bool:
        """
        Return True if the task is a MappedOperator or is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            if self.get_closest_mapped_task_group() is not None:
                self._needs_expansion = True
            else:
                self._needs_expansion = False
        return self._needs_expansion
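
    # A minimal sketch, assuming hypothetical @task/@task_group definitions: a
    # task inside a mapped task group needs expansion even if it is not itself
    # a MappedOperator:
    #
    #     @task_group
    #     def tg(arg):
    #         do_work(arg)  # needs expansion: its enclosing group is mapped
    #
    #     tg.expand(arg=[1, 2])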