1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import datetime
21import logging
22from abc import abstractmethod
23from collections.abc import (
24 Callable,
25 Collection,
26 Iterable,
27 Iterator,
28)
29from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias
30
31import methodtools
32
33from airflow.sdk import TriggerRule, WeightRule
34from airflow.sdk.configuration import conf
35from airflow.sdk.definitions._internal.mixins import DependencyMixin
36from airflow.sdk.definitions._internal.node import DAGNode
37from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
38from airflow.sdk.definitions._internal.templater import Templater
39from airflow.sdk.definitions.context import Context
40
41if TYPE_CHECKING:
42 import jinja2
43
44 from airflow.sdk.bases.operator import BaseOperator
45 from airflow.sdk.bases.operatorlink import BaseOperatorLink
46 from airflow.sdk.definitions.dag import DAG
47 from airflow.sdk.definitions.mappedoperator import MappedOperator
48 from airflow.sdk.definitions.taskgroup import MappedTaskGroup
49
# Signature of task state-change callbacks (e.g. on_success/on_failure);
# operators accept a single callable, a list of them, or None.
TaskStateChangeCallback = Callable[[Context], None]
TaskStateChangeCallbackAttrType: TypeAlias = TaskStateChangeCallback | list[TaskStateChangeCallback] | None

# Operator argument defaults, mostly sourced from the Airflow configuration.
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_POOL_NAME = "default_pool"
DEFAULT_PRIORITY_WEIGHT: int = 1
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
MINIMUM_PRIORITY_WEIGHT: int = -2147483648
MAXIMUM_PRIORITY_WEIGHT: int = 2147483647
DEFAULT_EXECUTOR: str | None = None
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
DEFAULT_RETRY_DELAY_MULTIPLIER: float = 2.0
# Cap on retry delay growth, in seconds (default: 24 hours).
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

# TODO: Task-SDK -- these defaults should be overridable from the Airflow config
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)

# Module-level logger for this file.
log = logging.getLogger(__name__)
84
85
class NotMapped(Exception):
    """Raised when mapping info is requested for a task that is neither mapped nor inside a mapped task group."""
89
class AbstractOperator(Templater, DAGNode):
    """
    Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """

    # Either the concrete operator class, or a dict standing in for one
    # (presumably the serialized representation -- verify against the serializer).
    operator_class: type[BaseOperator] | dict[str, Any]

    # Priority weight of this task; see MINIMUM/MAXIMUM_PRIORITY_WEIGHT above.
    priority_weight: int

    # Defines the operator level extra links.
    operator_extra_links: Collection[BaseOperatorLink] = ()

    # Task owner and the task's identifier within its Dag.
    owner: str
    task_id: str

    # Lineage/asset outlets and inlets attached to this task.
    outlets: list
    inlets: list

    # Rule deciding when this task may run relative to its upstream tasks.
    trigger_rule: TriggerRule
    # Tri-state cache backing get_needs_expansion(); None means "not computed yet".
    _needs_expansion: bool | None = None
    # Whether failure of this (teardown) task should fail the whole dag run;
    # exposed via the on_failure_fail_dagrun property below.
    _on_failure_fail_dagrun = False
    # Setup/teardown flags; see as_setup()/as_teardown().
    is_setup: bool = False
    is_teardown: bool = False

    # Attribute names that should never be rendered in the UI details view.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
        (
            "log",
            "dag",  # We show dag_id, don't need to show this too
            "node_id",  # Duplicates task_id
            "task_group",  # Doesn't have a useful repr, no point showing in UI
            "inherits_from_empty_operator",  # impl detail
            "inherits_from_skipmixin",  # impl detail
            # Decide whether to start task execution from triggerer
            "start_trigger_args",
            "start_from_trigger",
            # For compatibility with TG, for operators these are just the current task, no point showing
            "roots",
            "leaves",
            # These lists are already shown via *_task_ids
            "upstream_list",
            "downstream_list",
            # Not useful, implementation detail, already shown elsewhere
            "global_operator_extra_link_dict",
            "operator_extra_link_dict",
        )
    )

    @property
    def task_type(self) -> str:
        """Name of the task type; concrete subclasses must implement this."""
        raise NotImplementedError()

    @property
    def operator_name(self) -> str:
        """Display name of the operator; concrete subclasses must implement this."""
        raise NotImplementedError()

    # Class-level flags that sensor/mapped/skip-capable subclasses override.
    _is_sensor: bool = False
    _is_mapped: bool = False
    _can_skip_downstream: bool = False
157
158 @property
159 def dag_id(self) -> str:
160 """Returns dag id if it has one or an adhoc + owner."""
161 dag = self.get_dag()
162 if dag:
163 return dag.dag_id
164 return f"adhoc_{self.owner}"
165
    @property
    def node_id(self) -> str:
        # DAGNode interface: an operator's node id is simply its task_id.
        return self.task_id
169
    # Human-friendly name for the task (used by label below); abstract,
    # concrete subclasses must provide it.
    @property
    @abstractmethod
    def task_display_name(self) -> str: ...
173
    @property
    def is_mapped(self):
        """Whether this operator is a mapped operator (backed by the _is_mapped class flag)."""
        return self._is_mapped
177
178 @property
179 def label(self) -> str | None:
180 if self.task_display_name and self.task_display_name != self.task_id:
181 return self.task_display_name
182 # Prefix handling if no display is given is cloned from taskmixin for compatibility
183 tg = self.task_group
184 if tg and tg.node_id and tg.prefix_group_id:
185 # "task_group_id.task_id" -> "task_id"
186 return self.task_id[len(tg.node_id) + 1 :]
187 return self.task_id
188
    @property
    def on_failure_fail_dagrun(self):
        """
        Whether the operator should fail the dagrun on failure.

        :meta private:
        """
        # Backed by _on_failure_fail_dagrun; only settable on teardown tasks (see setter).
        return self._on_failure_fail_dagrun
197
    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, value):
        """
        Setter for on_failure_fail_dagrun property.

        :meta private:
        """
        # Only teardown tasks may fail the whole dag run on failure; reject
        # enabling the flag on anything else.
        if value is True and self.is_teardown is not True:
            raise ValueError(
                f"Cannot set task on_failure_fail_dagrun for "
                f"'{self.task_id}' because it is not a teardown task."
            )
        self._on_failure_fail_dagrun = value
211
    @property
    def inherits_from_empty_operator(self):
        """Used to determine if an Operator is inherited from EmptyOperator."""
        # This looks like `isinstance(self, EmptyOperator) would work, but this also
        # needs to cope when `self` is a Serialized instance of a EmptyOperator or one
        # of its subclasses (which don't inherit from anything but BaseOperator).
        # Hence the duck-typed flag lookup with a False default.
        return getattr(self, "_is_empty", False)
219
    @property
    def inherits_from_skipmixin(self):
        """Used to determine if an Operator is inherited from SkipMixin or its subclasses (e.g., BranchMixin)."""
        # Same duck-typed approach as inherits_from_empty_operator, so it also
        # works for serialized instances.
        return getattr(self, "_can_skip_downstream", False)
224
    def as_setup(self):
        """Mark this task as a setup task and return it, so the call can be chained."""
        self.is_setup = True
        return self
228
229 def as_teardown(
230 self,
231 *,
232 setups: BaseOperator | Iterable[BaseOperator] | None = None,
233 on_failure_fail_dagrun: bool | None = None,
234 ):
235 self.is_teardown = True
236 self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
237 if on_failure_fail_dagrun is not None:
238 self.on_failure_fail_dagrun = on_failure_fail_dagrun
239 if setups is not None:
240 setups = [setups] if isinstance(setups, DependencyMixin) else setups
241 for s in setups:
242 s.is_setup = True
243 s >> self
244 return self
245
    def __enter__(self):
        # Context-manager entry is only meaningful for setup/teardown tasks,
        # which use the `with` block to scope the "work" tasks between them.
        if not self.is_setup and not self.is_teardown:
            raise RuntimeError("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self)
        return SetupTeardownContext
251
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Finalize the setup/teardown scope opened in __enter__.
        SetupTeardownContext.set_work_task_roots_and_leaves()
254
    # TODO: Task-SDK -- Should the following methods be removed?
    # get_template_env
    # _render
258 def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
259 """Get the template environment for rendering templates."""
260 if dag is None:
261 dag = self.get_dag()
262 return super().get_template_env(dag=dag)
263
264 def _render(self, template, context, dag: DAG | None = None):
265 if dag is None:
266 dag = self.get_dag()
267 return super()._render(template, context, dag=dag)
268
    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """
        Override the base to use custom error logging.

        Renders each named attribute of *parent* in place, skipping falsy
        values and logging (with secrets masked) before re-raising on failure.

        :param parent: Object whose template-field attributes are rendered in place.
        :param template_fields: Attribute names on *parent* to render.
        :param context: Execution context passed to the templater.
        :param jinja_env: Jinja environment used for rendering.
        :param seen_oids: Object ids already rendered, to avoid re-rendering shared objects.
        :raises AttributeError: If a configured template field does not exist on *parent*.
        """
        for attr_name in template_fields:
            try:
                value = getattr(parent, attr_name)
            except AttributeError:
                raise AttributeError(
                    f"{attr_name!r} is configured as a template field "
                    f"but {parent.task_type} does not have this attribute."
                )
            try:
                # Skip falsy values -- nothing useful to render.
                if not value:
                    continue
            except Exception:
                # This may happen if the templated field points to a class which does not support `__bool__`,
                # such as Pandas DataFrames:
                # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
                log.info(
                    "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                    type(value).__name__,
                    self.task_id,
                    attr_name,
                )
                # We may still want to render custom classes which do not support __bool__
                pass

            try:
                # Callable template fields are invoked rather than string-rendered.
                if callable(value):
                    rendered_content = value(context=context, jinja_env=jinja_env)
                else:
                    rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            except Exception:
                # Mask sensitive values in the template before logging
                from airflow.sdk._shared.secrets_masker import redact

                masked_value = redact(value)
                log.exception(
                    "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                    self.task_id,
                    attr_name,
                    masked_value,
                )
                raise
            else:
                setattr(parent, attr_name, rendered_content)
321
322 def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
323 """
324 Return mapped nodes that are direct dependencies of the current task.
325
326 For now, this walks the entire Dag to find mapped nodes that has this
327 current task as an upstream. We cannot use ``downstream_list`` since it
328 only contains operators, not task groups. In the future, we should
329 provide a way to record a Dag node's all downstream nodes instead.
330
331 Note that this does not guarantee the returned tasks actually use the
332 current task for task mapping, but only checks those task are mapped
333 operators, and are downstreams of the current task.
334
335 To get a list of tasks that uses the current task for task mapping, use
336 :meth:`iter_mapped_dependants` instead.
337 """
338 from airflow.sdk.definitions.mappedoperator import MappedOperator
339 from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
340
341 def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
342 """
343 Recursively walk children in a task group.
344
345 This yields all direct children (including both tasks and task
346 groups), and all children of any task groups.
347 """
348 for key, child in group.children.items():
349 yield key, child
350 if isinstance(child, TaskGroup):
351 yield from _walk_group(child)
352
353 dag = self.get_dag()
354 if not dag:
355 raise RuntimeError("Cannot check for mapped dependants when not attached to a Dag")
356 for key, child in _walk_group(dag.task_group):
357 if key == self.node_id:
358 continue
359 if not isinstance(child, (MappedOperator, MappedTaskGroup)):
360 continue
361 if self.node_id in child.upstream_task_ids:
362 yield child
363
364 def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
365 """
366 Return mapped nodes that depend on the current task the expansion.
367
368 For now, this walks the entire Dag to find mapped nodes that has this
369 current task as an upstream. We cannot use ``downstream_list`` since it
370 only contains operators, not task groups. In the future, we should
371 provide a way to record a Dag node's all downstream nodes instead.
372 """
373 return (
374 downstream
375 for downstream in self._iter_all_mapped_downstreams()
376 if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
377 )
378
379 def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
380 """
381 Return mapped task groups this task belongs to.
382
383 Groups are returned from the innermost to the outmost.
384
385 :meta private:
386 """
387 if (group := self.task_group) is None:
388 return
389 yield from group.iter_mapped_task_groups()
390
391 def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
392 """
393 Get the mapped task group "closest" to this task in the Dag.
394
395 :meta private:
396 """
397 return next(self.iter_mapped_task_groups(), None)
398
399 def get_needs_expansion(self) -> bool:
400 """
401 Return true if the task is MappedOperator or is in a mapped task group.
402
403 :meta private:
404 """
405 if self._needs_expansion is None:
406 if self.get_closest_mapped_task_group() is not None:
407 self._needs_expansion = True
408 else:
409 self._needs_expansion = False
410 return self._needs_expansion
411
412 @methodtools.lru_cache(maxsize=None)
413 def get_parse_time_mapped_ti_count(self) -> int:
414 """
415 Return the number of mapped task instances that can be created on Dag run creation.
416
417 This only considers literal mapped arguments, and would return *None*
418 when any non-literal values are used for mapping.
419
420 :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
421 :raise NotMapped: If the operator is neither mapped, nor has any parent
422 mapped task groups.
423 :return: Total number of mapped TIs this task should have.
424 """
425 group = self.get_closest_mapped_task_group()
426 if group is None:
427 raise NotMapped()
428 return group.get_parse_time_mapped_ti_count()