#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import logging
from abc import abstractmethod
from collections.abc import (
    Callable,
    Collection,
    Iterable,
    Iterator,
)
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias

from airflow.sdk import TriggerRule, WeightRule
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.mixins import DependencyMixin
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
from airflow.sdk.definitions._internal.templater import Templater
from airflow.sdk.definitions.context import Context

if TYPE_CHECKING:
    import jinja2

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.bases.operatorlink import BaseOperatorLink
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.definitions.mappedoperator import MappedOperator
    from airflow.sdk.definitions.taskgroup import MappedTaskGroup

TaskStateChangeCallback = Callable[[Context], None]
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None
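
# A hedged sketch of a callback matching ``TaskStateChangeCallback``. The
# ``context`` mapping is supplied by the runtime; "task_instance" is a standard
# context key, but the callback body here is purely illustrative:
#
#     def notify(context: Context) -> None:
#         ti = context["task_instance"]
#         print(f"state change for {ti.task_id}")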

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)
DEFAULT_EMAIL_ON_FAILURE: bool = conf.getboolean("email", "default_email_on_failure", fallback=True)
DEFAULT_EMAIL_ON_RETRY: bool = conf.getboolean("email", "default_email_on_retry", fallback=True)
log = logging.getLogger(__name__)


class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is about sharing implementations rather than defining a
    common interface. Unfortunately it is difficult to use it as the common
    base class for typing because BaseOperator carries too much historical
    baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    operator_class: type[BaseOperator] | dict[str, Any]

    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    owner: str
    task_id: str

    outlets: list
    inlets: list

    trigger_rule: TriggerRule
    _needs_expansion: bool | None = None
    _on_failure_fail_dagrun = False
    is_setup: bool = False
    is_teardown: bool = False

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    @property
    def task_type(self) -> str:
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        raise NotImplementedError()

    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False

    @property
    def dag_id(self) -> str:
        """Return the dag id if the task has one, otherwise an adhoc id derived from the owner."""
        dag = self.get_dag()
        if dag:
            return dag.dag_id
        return f"adhoc_{self.owner}"

    @property
    def node_id(self) -> str:
        return self.task_id

    @property
    @abstractmethod
    def task_display_name(self) -> str: ...

    @property
    def is_mapped(self):
        return self._is_mapped

    @property
    def label(self) -> str | None:
        if self.task_display_name and self.task_display_name != self.task_id:
            return self.task_display_name
        # Prefix handling when no display name is given is cloned from taskmixin for compatibility
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.task_id[len(tg.node_id) + 1 :]
        return self.task_id
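
    # Illustration of the prefix stripping above: a task "t1" inside task group
    # "tg" (with ``prefix_group_id=True``) has task_id "tg.t1" but label "t1".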

    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        return self._on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for on_failure_fail_dagrun property.

        :meta private:
        """
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value

    @property
    def inherits_from_empty_operator(self):
        """Used to determine if an Operator inherits from EmptyOperator."""
        # This looks like ``isinstance(self, EmptyOperator)`` would work, but this also
        # needs to cope when `self` is a serialized instance of an EmptyOperator or one
        # of its subclasses (which don't inherit from anything but BaseOperator).
        return getattr(self, "_is_empty", False)

    @property
    def inherits_from_skipmixin(self):
        """Used to determine if an Operator inherits from SkipMixin or its subclasses (e.g., BranchMixin)."""
        return getattr(self, "_can_skip_downstream", False)

    def as_setup(self):
        self.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | None = None,
        on_failure_fail_dagrun: bool | None = None,
    ):
        self.is_teardown = True
        self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
        if on_failure_fail_dagrun is not None:
            self.on_failure_fail_dagrun = on_failure_fail_dagrun
        if setups is not None:
            setups = [setups] if isinstance(setups, DependencyMixin) else setups
            for s in setups:
                s.is_setup = True
                s >> self
        return self
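
    # A minimal usage sketch for the two helpers above, with hypothetical
    # operators ``create_file``, ``work``, and ``delete_file`` in one Dag:
    #
    #     create_file >> work >> delete_file
    #     delete_file.as_teardown(setups=create_file, on_failure_fail_dagrun=True)
    #
    # ``as_teardown`` marks ``create_file`` as a setup, wires it upstream of the
    # teardown, and switches the teardown's trigger rule to ALL_DONE_SETUP_SUCCESS.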

    def __enter__(self):
        if not self.is_setup and not self.is_teardown:
            raise RuntimeError("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()
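
    # The context-manager form builds on the same machinery: entering a
    # setup/teardown task pushes it onto SetupTeardownContext, and tasks created
    # inside the block are wired in as "work" tasks on exit. A hedged sketch
    # with hypothetical operators:
    #
    #     with delete_file.as_teardown(setups=create_file):
    #         work = BashOperator(task_id="work", bash_command="...")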

    # TODO: Task-SDK -- Should the following methods be removed?
    # get_template_env
    # _render
    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """Get the template environment for rendering templates."""
        if dag is None:
            dag = self.get_dag()
        return super().get_template_env(dag=dag)

    def _render(self, template, context, dag: DAG | None = None):
        if dag is None:
            dag = self.get_dag()
        return super()._render(template, context, dag=dag)

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Override the base to use custom error logging."""
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                if callable(value):
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)
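
    # Rough illustration of the rendering loop above, assuming a context where
    # ``ds`` is "2025-01-01" (BashOperator templates ``bash_command``):
    #
    #     op = BashOperator(task_id="t", bash_command="echo {{ ds }}")
    #     op._do_render_template_fields(op, op.template_fields, context, env, set())
    #     assert op.bash_command == "echo 2025-01-01"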

    def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that are direct dependencies of the current task.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.

        Note that this does not guarantee the returned tasks actually use the
        current task for task mapping; it only checks that those tasks are
        mapped operators and are downstream of the current task.

        To get a list of tasks that use the current task for task mapping, use
        :meth:`iter_mapped_dependants` instead.
        """
        from airflow.sdk.definitions.mappedoperator import MappedOperator
        from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup

        def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
            """
            Recursively walk children in a task group.

            This yields all direct children (including both tasks and task
            groups), and all children of any task groups.
            """
            for key, child in group.children.items():
                yield key, child
                if isinstance(child, TaskGroup):
                    yield from _walk_group(child)

        dag = self.get_dag()
        if not dag:
            raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag")
        for key, child in _walk_group(dag.task_group):
            if key == self.node_id:
                continue
            if not isinstance(child, (MappedOperator, MappedTaskGroup)):
                continue
            if self.node_id in child.upstream_task_ids:
                yield child

    def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
        """
        Return mapped nodes that depend on the current task for expansion.

        For now, this walks the entire Dag to find mapped nodes that have this
        current task as an upstream. We cannot use ``downstream_list`` since it
        only contains operators, not task groups. In the future, we should
        provide a way to record all of a Dag node's downstream nodes instead.
        """
        return (
            downstream
            for downstream in self._iter_all_mapped_downstreams()
            if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
        )
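
    # Sketch of the distinction between the two iterators above, assuming a
    # hypothetical ``MyOperator`` where ``make_list`` feeds the expansion and
    # ``other`` is merely upstream:
    #
    #     mapped = MyOperator.partial(task_id="m").expand(x=make_list.output)
    #     other >> mapped
    #
    # ``_iter_all_mapped_downstreams`` yields ``mapped`` for both ``make_list``
    # and ``other``; ``iter_mapped_dependants`` yields it only for ``make_list``.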

    def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
        """
        Return mapped task groups this task belongs to.

        Groups are returned from the innermost to the outermost.

        :meta private:
        """
        if (group := self.task_group) is None:
            return
        yield from group.iter_mapped_task_groups()
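
    # For a task nested in mapped task groups ``outer`` -> ``inner``, this
    # yields ``inner`` first, then ``outer``, so ``get_closest_mapped_task_group``
    # below returns ``inner``.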

    def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
        """
        Get the mapped task group "closest" to this task in the Dag.

        :meta private:
        """
        return next(self.iter_mapped_task_groups(), None)

    def get_needs_expansion(self) -> bool:
        """
        Return True if the task is a MappedOperator, or is in a mapped task group.

        :meta private:
        """
        if self._needs_expansion is None:
            if self.get_closest_mapped_task_group() is not None:
                self._needs_expansion = True
            else:
                self._needs_expansion = False
        return self._needs_expansion