1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import datetime
21import logging
22from abc import abstractmethod
23from collections.abc import (
24 Callable,
25 Collection,
26 Iterable,
27 Iterator,
28)
29from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias
30
31from airflow.sdk import TriggerRule, WeightRule
32from airflow.sdk.configuration import conf
33from airflow.sdk.definitions._internal.mixins import DependencyMixin
34from airflow.sdk.definitions._internal.node import DAGNode
35from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
36from airflow.sdk.definitions._internal.templater import Templater
37from airflow.sdk.definitions.context import Context
38
39if TYPE_CHECKING:
40 import jinja2
41
42 from airflow.sdk.bases.operator import BaseOperator
43 from airflow.sdk.bases.operatorlink import BaseOperatorLink
44 from airflow.sdk.definitions.dag import DAG
45 from airflow.sdk.definitions.mappedoperator import MappedOperator
46 from airflow.sdk.definitions.taskgroup import MappedTaskGroup
47
# Type of a single task-state-change callback: a callable receiving the task Context.
TaskStateChangeCallback: TypeAlias = Callable[[Context], None]
# Attribute form: a single callback, a list of callbacks, or nothing.
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None

# Operator defaults sourced from the Airflow configuration at import time.
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME: str = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
# Default delay between retries; config value is in seconds.
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
# Upper bound on retry delay, in seconds (default: 24 hours).
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)
DEFAULT_EMAIL_ON_FAILURE: bool = conf.getboolean("email", "default_email_on_failure", fallback=True)
DEFAULT_EMAIL_ON_RETRY: bool = conf.getboolean("email", "default_email_on_retry", fallback=True)
log = logging.getLogger(__name__)
82
83
class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    # The concrete operator class; per the annotation it may also be a plain
    # dict of class info (presumably for serialized/mapped operators -- TODO confirm).
    operator_class: type[BaseOperator] | dict[str, Any]

    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    owner: str
    task_id: str

    outlets: list
    inlets: list

    trigger_rule: TriggerRule
    # Tri-state cache for get_needs_expansion(); None means "not yet computed".
    _needs_expansion: bool | None = None
    # Backing field for the on_failure_fail_dagrun property; only settable on teardowns.
    _on_failure_fail_dagrun = False
    is_setup: bool = False
    is_teardown: bool = False

    # Attribute names deliberately hidden from the UI's task detail view.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )
139
140 @property
141 def is_async(self) -> bool:
142 return False
143
144 @property
145 def task_type(self) -> str:
146 raise NotImplementedError()
147
148 @property
149 def operator_name(self) -> str:
150 raise NotImplementedError()
151
    # Internal classification flags; _is_mapped backs the is_mapped property and
    # _can_skip_downstream backs inherits_from_skipmixin below.
    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False
155
156 @property
157 def dag_id(self) -> str:
158 """Returns dag id if it has one or an adhoc + owner."""
159 dag = self.get_dag()
160 if dag:
161 return dag.dag_id
162 return f"adhoc_{self.owner}"
163
164 @property
165 def node_id(self) -> str:
166 return self.task_id
167
    @property
    @abstractmethod
    def task_display_name(self) -> str:
        """Human-readable name shown for this task; implemented by subclasses."""
        ...
171
172 @property
173 def is_mapped(self):
174 return self._is_mapped
175
176 @property
177 def label(self) -> str | None:
178 if self.task_display_name and self.task_display_name != self.task_id:
179 return self.task_display_name
180 # Prefix handling if no display is given is cloned from taskmixin for compatibility
181 tg = self.task_group
182 if tg and tg.node_id and tg.prefix_group_id:
183 # "task_group_id.task_id" -> "task_id"
184 return self.task_id[len(tg.node_id) + 1 :]
185 return self.task_id
186
187 @property
188 def on_failure_fail_dagrun(self):
189 """
190 Whether the operator should fail the dagrun on failure.
191
192 :meta private:
193 """
194 return self._on_failure_fail_dagrun
195
196 @on_failure_fail_dagrun.setter
197 def on_failure_fail_dagrun(self, value):
198 """
199 Setter for on_failure_fail_dagrun property.
200
201 :meta private:
202 """
203 if value is True and self.is_teardown is not True:
204 raise ValueError(
205 f"Cannot set task on_failure_fail_dagrun for "
206 f"'{self.task_id}' because it is not a teardown task."
207 )
208 self._on_failure_fail_dagrun = value
209
210 @property
211 def inherits_from_empty_operator(self):
212 """Used to determine if an Operator is inherited from EmptyOperator."""
213 # This looks like `isinstance(self, EmptyOperator) would work, but this also
214 # needs to cope when `self` is a Serialized instance of a EmptyOperator or one
215 # of its subclasses (which don't inherit from anything but BaseOperator).
216 return getattr(self, "_is_empty", False)
217
218 @property
219 def inherits_from_skipmixin(self):
220 """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin)."""
221 return getattr(self, "_can_skip_downstream", False)
222
223 def as_setup(self):
224 self.is_setup = True
225 return self
226
227 def as_teardown(
228 self,
229 *,
230 setups: BaseOperator | Iterable[BaseOperator] | None = None,
231 on_failure_fail_dagrun: bool | None = None,
232 ):
233 self.is_teardown = True
234 self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
235 if on_failure_fail_dagrun is not None:
236 self.on_failure_fail_dagrun = on_failure_fail_dagrun
237 if setups is not None:
238 setups = [setups] if isinstance(setups, DependencyMixin) else setups
239 for s in setups:
240 s.is_setup = True
241 s >> self
242 return self
243
244 def __enter__(self):
245 if not self.is_setup and not self.is_teardown:
246 raise RuntimeError("Only setup/teardown tasks can be used as context managers.")
247 SetupTeardownContext.push_setup_teardown_task(self)
248 return SetupTeardownContext
249
250 def __exit__(self, exc_type, exc_val, exc_tb):
251 SetupTeardownContext.set_work_task_roots_and_leaves()
252
253 # TODO: Task-SDK -- Should the following methods removed?
254 # get_template_env
255 # _render
256 def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
257 """Get the template environment for rendering templates."""
258 if dag is None:
259 dag = self.get_dag()
260 return super().get_template_env(dag=dag)
261
262 def _render(self, template, context, dag: DAG | None = None):
263 if dag is None:
264 dag = self.get_dag()
265 return super()._render(template, context, dag=dag)
266
    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """
        Render each templated field on *parent* in place.

        Overrides the base implementation to add task-aware error logging and
        secret masking when rendering fails.

        :param parent: Object whose attributes named in *template_fields* are rendered.
        :param template_fields: Names of attributes on *parent* to render.
        :param context: Template context; also passed to callable field values.
        :param jinja_env: Jinja environment used for rendering.
        :param seen_oids: Object ids already seen (forwarded to ``render_template``).
        :raises AttributeError: If a configured template field is missing on *parent*.
        """
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                # Re-raise with a message naming the misconfigured field.
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                # Falsy values (empty string/list, None, ...) have nothing to render.
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                if callable(value):
                    # Callable field values are invoked with the context instead of templated.
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                # Only overwrite the attribute when rendering succeeded.
                setattr(parent, attr_name, rendered_content)
319
320 def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
321 """
322 Return mapped nodes that are direct dependencies of the current task.
323
324 For now, this walks the entire Dag to find mapped nodes that has this
325 current task as an upstream. We cannot use ``downstream_list`` since it
326 only contains operators, not task groups. In the future, we should
327 provide a way to record a Dag node's all downstream nodes instead.
328
329 Note that this does not guarantee the returned tasks actually use the
330 current task for task mapping, but only checks those task are mapped
331 operators, and are downstreams of the current task.
332
333 To get a list of tasks that uses the current task for task mapping, use
334 :meth:`iter_mapped_dependants` instead.
335 """
336 from airflow.sdk.definitions.mappedoperator import MappedOperator
337 from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
338
339 def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
340 """
341 Recursively walk children in a task group.
342
343 This yields all direct children (including both tasks and task
344 groups), and all children of any task groups.
345 """
346 for key, child in group.children.items():
347 yield key, child
348 if isinstance(child, TaskGroup):
349 yield from _walk_group(child)
350
351 dag = self.get_dag()
352 if not dag:
353 raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag")
354 for key, child in _walk_group(dag.task_group):
355 if key == self.node_id:
356 continue
357 if not isinstance(child, (MappedOperator, MappedTaskGroup)):
358 continue
359 if self.node_id in child.upstream_task_ids:
360 yield child
361
362 def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
363 """
364 Return mapped nodes that depend on the current task the expansion.
365
366 For now, this walks the entire Dag to find mapped nodes that has this
367 current task as an upstream. We cannot use ``downstream_list`` since it
368 only contains operators, not task groups. In the future, we should
369 provide a way to record a Dag node's all downstream nodes instead.
370 """
371 return (
372 downstream
373 for downstream in self._iter_all_mapped_downstreams()
374 if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
375 )
376
377 def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
378 """
379 Return mapped task groups this task belongs to.
380
381 Groups are returned from the innermost to the outmost.
382
383 :meta private:
384 """
385 if (group := self.task_group) is None:
386 return
387 yield from group.iter_mapped_task_groups()
388
389 def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
390 """
391 Get the mapped task group "closest" to this task in the Dag.
392
393 :meta private:
394 """
395 return next(self.iter_mapped_task_groups(), None)
396
397 def get_needs_expansion(self) -> bool:
398 """
399 Return true if the task is MappedOperator or is in a mapped task group.
400
401 :meta private:
402 """
403 if self._needs_expansion is None:
404 if self.get_closest_mapped_task_group() is not None:
405 self._needs_expansion = True
406 else:
407 self._needs_expansion = False
408 return self._needs_expansion