1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import inspect
20import itertools
21import re
22import textwrap
23import warnings
24from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
25from contextlib import suppress
26from functools import cached_property, partial, update_wrapper
27from typing import TYPE_CHECKING, Any, ClassVar, Generic, ParamSpec, Protocol, TypeVar, cast, overload
28
29import attr
30import typing_extensions
31
32from airflow.sdk import TriggerRule, timezone
33from airflow.sdk.bases.operator import (
34 BaseOperator,
35 coerce_resources,
36 coerce_timedelta,
37 get_merged_defaults,
38 parse_retries,
39)
40from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
41from airflow.sdk.definitions._internal.decorators import remove_task_decorator
42from airflow.sdk.definitions._internal.expandinput import (
43 EXPAND_INPUT_EMPTY,
44 DictOfListsExpandInput,
45 ListOfDictsExpandInput,
46 is_mappable,
47)
48from airflow.sdk.definitions._internal.types import NOTSET
49from airflow.sdk.definitions.asset import Asset
50from airflow.sdk.definitions.context import KNOWN_CONTEXT_KEYS
51from airflow.sdk.definitions.mappedoperator import (
52 MappedOperator,
53 ensure_xcomarg_return_value,
54 prevent_duplicates,
55)
56from airflow.sdk.definitions.xcom_arg import XComArg
57
58if TYPE_CHECKING:
59 from airflow.sdk.definitions._internal.expandinput import (
60 ExpandInput,
61 OperatorExpandArgument,
62 OperatorExpandKwargsArgument,
63 )
64 from airflow.sdk.definitions.context import Context
65 from airflow.sdk.definitions.dag import DAG
66 from airflow.sdk.definitions.mappedoperator import ValidationSource
67 from airflow.sdk.definitions.taskgroup import TaskGroup
68
69
class ExpandableFactory(Protocol):
    """
    Protocol providing inspection against wrapped function.

    This is used in ``validate_expand_kwargs`` and implemented by function
    decorators like ``@task`` and ``@task_group``.

    :meta private:
    """

    function: Callable

    @cached_property
    def function_signature(self) -> inspect.Signature:
        """Signature of the wrapped function, computed once per instance."""
        return inspect.signature(self.function)

    @cached_property
    def _mappable_function_argument_names(self) -> set[str]:
        """Arguments that can be mapped against."""
        return set(self.function_signature.parameters)

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> None:
        """Ensure that all arguments passed to operator-mapping functions are accounted for."""
        # A **kwargs parameter accepts any name, so nothing can be "unexpected".
        signature_params = self.function_signature.parameters
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in signature_params.values()):
            return
        unknown = dict(kwargs)
        for name in self._mappable_function_argument_names:
            candidate = unknown.pop(name, NOTSET)
            # expand() arguments must each be an iterable-like value to map over.
            if func == "expand" and candidate is not NOTSET and not is_mappable(candidate):
                raise ValueError(
                    f"expand() got an unexpected type {type(candidate).__name__!r}"
                    f" for keyword argument {name!r}"
                )
        if unknown:
            if len(unknown) == 1:
                raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(unknown))!r}")
            listed = ", ".join(repr(n) for n in unknown)
            raise TypeError(f"{func}() got unexpected keyword arguments {listed}")
109
110
def get_unique_task_id(
    task_id: str,
    dag: DAG | None = None,
    task_group: TaskGroup | None = None,
) -> str:
    """
    Generate unique task id given a Dag (or if run in a Dag context).

    IDs are generated by appending a unique number to the end of
    the original task id.

    Example:
        task_id
        task_id__1
        task_id__2
        ...
        task_id__20

    :param task_id: The requested task id.
    :param dag: Dag to deduplicate against; falls back to the active Dag context.
    :param task_group: Task group the task belongs to; falls back to the active
        TaskGroup context of *dag*.
    :return: ``task_id`` unchanged if it is free, otherwise ``task_id__N`` with
        the next unused suffix number.
    """
    dag = dag or DagContext.get_current()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id

    if tg_task_id not in dag.task_ids:
        return task_id

    def _find_id_suffixes(dag: DAG) -> Iterator[int]:
        # Escape the prefix so characters that are regex metacharacters in task
        # or group ids (e.g. "." in "group.task") are matched literally instead
        # of matching arbitrary characters and yielding false-positive suffixes.
        prefix = re.escape(re.split(r"__\d+$", tg_task_id)[0])
        for existing_id in dag.task_ids:
            match = re.match(rf"^{prefix}__(\d+)$", existing_id)
            if match:
                yield int(match.group(1))
        yield 0  # Default if there's no matching task ID.

    core = re.split(r"__\d+$", task_id)[0]
    return f"{core}__{max(_find_id_suffixes(dag)) + 1}"
151
152
def unwrap_partial(fn: Callable) -> Callable:
    """Return the innermost callable beneath any ``functools.partial`` layers."""
    inner = fn
    while isinstance(inner, partial):
        inner = inner.func
    return inner
157
158
def unwrap_callable(func):
    """Strip decorator, partial, and wrapper layers to reach the underlying callable.

    Handles ``@task``-decorated objects and operator partials, then
    ``functools.partial`` chains, then ``functools.wraps``-style wrappers
    (via ``inspect.unwrap``).
    """
    from airflow.sdk.definitions.mappedoperator import OperatorPartial

    if isinstance(func, (_TaskDecorator, OperatorPartial)):
        # Both kinds expose the wrapped function, but under different attribute names.
        func = getattr(func, "function", getattr(func, "_func", func))

    candidate = unwrap_partial(func)

    with suppress(Exception):
        # inspect.unwrap may raise (e.g. on a __wrapped__ cycle); keep what we have.
        candidate = inspect.unwrap(candidate)

    return candidate
171
172
def is_async_callable(func):
    """Detect if a callable (possibly wrapped) is an async function."""
    target = unwrap_callable(func)

    if not callable(target):
        return False

    # Plain ``async def`` function (or bound method) after unwrapping.
    if inspect.iscoroutinefunction(target):
        return True

    # A regular (non-coroutine) Python function cannot be async at this point.
    if inspect.isfunction(target):
        return False

    # Otherwise it may be an object whose class defines an async __call__.
    call_attr = type(target).__call__  # Bandit-safe
    with suppress(Exception):
        call_attr = inspect.unwrap(call_attr)
    return inspect.iscoroutinefunction(call_attr)
193
194
class DecoratedOperator(BaseOperator):
    """
    Wraps a Python callable and captures args/kwargs when called for execution.

    :param python_callable: A reference to an object that is callable
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function (templated)
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable (templated)
    :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
        multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
    :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments
        that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the
        PythonOperator). This gives a user the option to upstream kwargs as needed.
    """

    # op_args/op_kwargs are template fields, rendered with the "py" renderer.
    template_fields: Sequence[str] = ("op_args", "op_kwargs")
    template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

    # since we won't mutate the arguments, we should just do the shallow copy
    # there are some cases we can't deepcopy the objects (e.g protobuf).
    shallow_copy_attrs: Sequence[str] = ("python_callable",)

    def __init__(
        self,
        *,
        python_callable: Callable,
        task_id: str,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        kwargs_to_upstream: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        if not getattr(self, "_BaseOperator__from_mapped", False):
            # If we are being created from calling unmap(), then don't mangle the task id
            task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
        self.python_callable = python_callable
        kwargs_to_upstream = kwargs_to_upstream or {}
        op_args = op_args or []
        op_kwargs = op_kwargs or {}

        # Check the decorated function's signature. We go through the argument
        # list and "fill in" defaults to arguments that are known context keys,
        # since values for those will be provided when the task is run. Since
        # we're not actually running the function, None is good enough here.
        signature = inspect.signature(python_callable)

        # Don't allow context argument defaults other than None to avoid ambiguities.
        faulty_parameters = [
            param.name
            for param in signature.parameters.values()
            if param.name in KNOWN_CONTEXT_KEYS and param.default not in (None, inspect.Parameter.empty)
        ]
        if faulty_parameters:
            message = f"Context key parameter {faulty_parameters[0]} can't have a default other than None"
            raise ValueError(message)

        parameters = [
            param.replace(default=None) if param.name in KNOWN_CONTEXT_KEYS else param
            for param in signature.parameters.values()
        ]
        try:
            signature = signature.replace(parameters=parameters)
        except ValueError as err:
            # Filling in defaults can produce an invalid signature (e.g. a
            # parameter without a default now following one that has a default).
            message = textwrap.dedent(
                f"""
                The function signature broke while assigning defaults to context key parameters.

                The decorator is replacing the signature
                > {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())})

                with
                > {python_callable.__name__}({", ".join(str(param) for param in parameters)})

                which isn't valid: {err}
                """
            )
            raise ValueError(message) from err

        # Check that arguments can be binded. There's a slight difference when
        # we do validation for task-mapping: Since there's no guarantee we can
        # receive enough arguments at parse time, we use bind_partial to simply
        # check all the arguments we know are valid. Whether these are enough
        # can only be known at execution time, when unmapping happens, and this
        # is called without the _airflow_mapped_validation_only flag.
        if kwargs.get("_airflow_mapped_validation_only"):
            signature.bind_partial(*op_args, **op_kwargs)
        else:
            signature.bind(*op_args, **op_kwargs)

        self.op_args = op_args
        self.op_kwargs = op_kwargs
        super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)

    @property
    def is_async(self) -> bool:
        """Whether the wrapped callable is an async (coroutine) function, unwrapping decorators."""
        return is_async_callable(self.python_callable)

    def execute(self, context: Context):
        # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
        # as well
        # Any Asset passed as an argument is recorded as an inlet for lineage.
        for arg in itertools.chain(self.op_args, self.op_kwargs.values()):
            if isinstance(arg, Asset):
                self.inlets.append(arg)
        return_value = super().execute(context)
        return self._handle_output(return_value=return_value)

    def _handle_output(self, return_value: Any):
        """
        Handle logic for whether a decorator needs to push a single return value or multiple return values.

        It sets outlets if any assets are found in the returned value(s)

        :param return_value: value returned by the wrapped callable; any
            ``Asset`` in it (directly, or as an element of a list) is added to
            this task's outlets. The value is returned unchanged.
        """
        if isinstance(return_value, Asset):
            self.outlets.append(return_value)
        if isinstance(return_value, list):
            for item in return_value:
                if isinstance(item, Asset):
                    self.outlets.append(item)
        return return_value

    def _hook_apply_defaults(self, *args, **kwargs):
        """Copy entries from ``default_args`` into ``op_kwargs`` for matching callable parameters.

        Only parameters named in the wrapped callable's signature are copied,
        and explicitly supplied ``op_kwargs`` entries win over defaults.
        """
        if "python_callable" not in kwargs:
            return args, kwargs

        python_callable = kwargs["python_callable"]
        default_args = kwargs.get("default_args") or {}
        op_kwargs = kwargs.get("op_kwargs") or {}
        f_sig = inspect.signature(python_callable)
        for arg in f_sig.parameters:
            if arg not in op_kwargs and arg in default_args:
                op_kwargs[arg] = default_args[arg]
        kwargs["op_kwargs"] = op_kwargs
        return args, kwargs

    def get_python_source(self):
        """Return the wrapped callable's source, dedented and with the task decorator removed."""
        raw_source = inspect.getsource(self.python_callable)
        # Comment-only lines are dropped before dedenting — presumably so
        # comments at odd indentation don't interfere with textwrap.dedent;
        # confirm before relying on this elsewhere.
        raw_source_lines = [line for line in raw_source.splitlines() if not line.strip().startswith("#")]
        res = textwrap.dedent("\n".join(raw_source_lines)) + "\n"
        res = remove_task_decorator(res, self.custom_operator_name)
        return res
340
341
# Typing helpers used by the generic decorator classes below.
FParams = ParamSpec("FParams")  # parameter specification of the wrapped callable

FReturn = TypeVar("FReturn")  # return type of the wrapped callable

OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")  # operator class used for execution
347
348
@attr.define(slots=False)
class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]):
    """
    Helper class for providing dynamic task mapping to decorated functions.

    ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function.

    :meta private:
    """

    # The decorated function to turn into an operator when called.
    function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
    # Operator class that will actually execute ``function``.
    operator_class: type[OperatorSubclass]
    # Whether the return value unrolls into multiple XComs; defaulted from the
    # return annotation by ``_infer_multiple_outputs`` below.
    multiple_outputs: bool = attr.ib()
    # Extra keyword arguments forwarded to ``operator_class`` on instantiation.
    kwargs: dict[str, Any] = attr.ib(factory=dict)

    # Name of the decorator (e.g. "task"), used in error messages.
    decorator_name: str = attr.ib(repr=False, default="task")

    # Marker for code elsewhere to recognize task-decorator objects.
    _airflow_is_task_decorator: ClassVar[bool] = True
    is_setup: bool = False
    is_teardown: bool = False
    on_failure_fail_dagrun: bool = False

    # This is set in __attrs_post_init__ by update_wrapper. Provided here for type hints.
    __wrapped__: Callable[FParams, FReturn] = attr.ib(init=False)

    @multiple_outputs.default
    def _infer_multiple_outputs(self):
        """Default ``multiple_outputs`` to True when the return annotation resolves to a Mapping."""
        if "return" not in self.function.__annotations__:
            # No return type annotation, nothing to infer
            return False

        try:
            # We only care about the return annotation, not anything about the parameters
            def fake(): ...

            fake.__annotations__ = {"return": self.function.__annotations__["return"]}

            return_type = typing_extensions.get_type_hints(fake, self.function.__globals__).get("return", Any)
        except NameError as e:
            # Unresolvable forward references are non-fatal: warn and fall back
            # to a single output.
            warnings.warn(
                f"Cannot infer multiple_outputs for TaskFlow function {self.function.__name__!r} with forward"
                f" type references that are not imported. (Error was {e})",
                stacklevel=4,
            )
            return False
        except TypeError:  # Can't evaluate return type.
            return False
        # Subscripted generics (e.g. dict[str, int]) expose the base type via __origin__.
        ttype = getattr(return_type, "__origin__", return_type)
        return isinstance(ttype, type) and issubclass(ttype, Mapping)

    def __attrs_post_init__(self):
        if "self" in self.function_signature.parameters:
            raise TypeError(f"@{self.decorator_name} does not support methods")
        self.kwargs.setdefault("task_id", self.function.__name__)
        # Make this object look like the wrapped function (__name__, __doc__, ...).
        update_wrapper(self, self.function)

    def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg:
        """Instantiate the operator for an unmapped call and return its XComArg."""
        if self.is_teardown:
            if "trigger_rule" in self.kwargs:
                raise ValueError("Trigger rule not configurable for teardown tasks.")
            self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
        on_failure_fail_dagrun = self.kwargs.pop("on_failure_fail_dagrun", self.on_failure_fail_dagrun)
        op = self.operator_class(
            python_callable=self.function,
            op_args=args,
            op_kwargs=kwargs,
            multiple_outputs=self.multiple_outputs,
            **self.kwargs,
        )
        op.is_setup = self.is_setup
        op.is_teardown = self.is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, op.doc_yaml]
        # Set the task's doc_md to the function's docstring if it exists and no other doc* args are set.
        if self.function.__doc__ and not any(op_doc_attrs):
            op.doc_md = self.function.__doc__
        return XComArg(op)

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]):
        # Ensure that context variables are not shadowed.
        context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs)
        if len(context_keys_being_mapped) == 1:
            # Singular message for a single offending name.
            (name,) = context_keys_being_mapped
            raise ValueError(f"cannot call {func}() on task context variable {name!r}")
        if context_keys_being_mapped:
            names = ", ".join(repr(n) for n in context_keys_being_mapped)
            raise ValueError(f"cannot call {func}() on task context variables {names}")

        super()._validate_arg_names(func, kwargs)

    def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
        """Create a mapped operator, expanding over each keyword argument's values."""
        if self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS and any(
            [isinstance(expanded, XComArg) for expanded in map_kwargs.values()]
        ):
            raise ValueError(
                "Task-generated mapping within a task using 'expand' is not allowed with trigger rule 'always'."
            )
        if not map_kwargs:
            raise TypeError("no arguments to expand against")
        self._validate_arg_names("expand", map_kwargs)
        prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        if self.is_teardown:
            if "trigger_rule" in self.kwargs:
                raise ValueError("Trigger rule not configurable for teardown tasks.")
            self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
        return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
        """Create a mapped operator from a list of complete kwargs dicts (or an XComArg yielding one)."""
        if (
            self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS
            and not isinstance(kwargs, XComArg)
            and any(
                [
                    isinstance(v, XComArg)
                    for kwarg in kwargs
                    if not isinstance(kwarg, XComArg)
                    for v in kwarg.values()
                ]
            )
        ):
            raise ValueError(
                "Task-generated mapping within a task using 'expand_kwargs' is not allowed with trigger rule 'always'."
            )
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # NOTE(review): the message reports the container's type, not
                    # the offending item's type — confirm whether that is intended.
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
        """Build the DecoratedMappedOperator shared by ``expand`` and ``expand_kwargs``."""
        ensure_xcomarg_return_value(expand_input.value)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current()
        task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag)

        default_args, partial_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs: dict[str, Any] = {
            "is_setup": self.is_setup,
            "is_teardown": self.is_teardown,
            "on_failure_fail_dagrun": self.on_failure_fail_dagrun,
        }
        # Only forward default_args entries that BaseOperator actually accepts.
        base_signature = inspect.signature(BaseOperator)
        ignore = {
            "default_args",  # This is target we are working on now.
            "kwargs",  # A common name for a keyword argument.
            "do_xcom_push",  # In the same boat as `multiple_outputs`
            "multiple_outputs",  # We will use `self.multiple_outputs` instead.
            "params",  # Already handled above `partial_params`.
            "task_concurrency",  # Deprecated(replaced by `max_active_tis_per_dag`).
        }
        partial_keys = set(base_signature.parameters) - ignore
        partial_kwargs.update({key: value for key, value in default_args.items() if key in partial_keys})
        # Explicit task kwargs take precedence over default_args.
        partial_kwargs.update(task_kwargs)

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
        if task_group:
            task_id = task_group.child_id(task_id)

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
        if "pool_slots" in partial_kwargs:
            if partial_kwargs["pool_slots"] < 1:
                dag_str = ""
                if dag:
                    dag_str = f" in dag {dag.dag_id}"
                raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")

        # Normalize fields that accept multiple input forms into canonical types.
        for fld, convert in (
            ("retries", parse_retries),
            ("retry_delay", coerce_timedelta),
            ("max_retry_delay", coerce_timedelta),
            ("resources", coerce_resources),
        ):
            if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET:
                partial_kwargs[fld] = convert(v)

        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast("Any", DecoratedMappedOperator)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        operator = _MappedOperator(
            operator_class=self.operator_class,
            expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=partial_params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            is_sensor=self.operator_class._is_sensor,
            can_skip_downstream=self.operator_class._can_skip_downstream,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            op_kwargs_expand_input=expand_input,
            disallow_kwargs_override=strict,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expand_input_attr="op_kwargs_expand_input",
            start_trigger_args=self.operator_class.start_trigger_args,
            start_from_trigger=self.operator_class.start_from_trigger,
        )
        return XComArg(operator=operator)

    def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        """Return a new decorator with the given ``op_kwargs`` pre-bound."""
        self._validate_arg_names("partial", kwargs)
        old_kwargs = self.kwargs.get("op_kwargs", {})
        prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
        # No overlap is possible after prevent_duplicates; this merges the
        # previously-bound kwargs into the new ones.
        kwargs.update(old_kwargs)
        return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})

    def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        """Return a copy of this decorator with ``kwargs`` overriding its operator kwargs."""
        result = attr.evolve(self, kwargs={**self.kwargs, **kwargs})
        # Explicitly carry over flags that may have been set after creation.
        setattr(result, "is_setup", self.is_setup)
        setattr(result, "is_teardown", self.is_teardown)
        setattr(result, "on_failure_fail_dagrun", self.on_failure_fail_dagrun)
        return result
598
599
@attr.define(kw_only=True, repr=False)
class DecoratedMappedOperator(MappedOperator):
    """MappedOperator implementation for @task-decorated task function."""

    # Whether the callable's return value unrolls into multiple XCom entries.
    multiple_outputs: bool
    # The decorated callable executed when the mapped task is unmapped.
    python_callable: Callable

    # We can't save these in expand_input because op_kwargs need to be present
    # in partial_kwargs, and MappedOperator prevents duplication.
    op_kwargs_expand_input: ExpandInput

    def __hash__(self):
        # Hash by identity: each instance represents a distinct mapped task.
        return id(self)

    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        # We only use op_kwargs_expand_input so this must always be empty.
        if self.expand_input is not EXPAND_INPUT_EMPTY:
            raise AssertionError(f"unexpected expand_input: {self.expand_input}")
        # Mapped values feed the operator's op_kwargs, not the operator's own arguments.
        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context)
        return {"op_kwargs": op_kwargs}, resolved_oids

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """Merge partial and mapped op_kwargs into the kwargs used to build the unmapped operator."""
        partial_op_kwargs = self.partial_kwargs["op_kwargs"]
        mapped_op_kwargs = mapped_kwargs["op_kwargs"]

        if strict:
            prevent_duplicates(partial_op_kwargs, mapped_op_kwargs, fail_reason="mapping already partial")

        # Mapped kwargs win over partial ones in the merged dict.
        kwargs = {
            "multiple_outputs": self.multiple_outputs,
            "python_callable": self.python_callable,
            "op_kwargs": {**partial_op_kwargs, **mapped_op_kwargs},
        }
        # Duplication was already checked above (when strict); skip the super check.
        return super()._get_unmap_kwargs(kwargs, strict=False)
634
635
class Task(Protocol, Generic[FParams, FReturn]):
    """
    Declaration of a @task-decorated callable for type-checking.

    An instance of this type inherits the call signature of the decorated
    function wrapped in it (not *exactly* since it actually returns an XComArg,
    but there's no way to express that right now), and provides two additional
    methods for task-mapping.

    This type is implemented by ``_TaskDecorator`` at runtime.
    """

    # Calling the task builds an (unmapped) operator and returns its XComArg.
    __call__: Callable[FParams, XComArg]

    # The original decorated function.
    function: Callable[FParams, FReturn]

    @property
    def __wrapped__(self) -> Callable[FParams, FReturn]: ...

    # Pre-bind keyword arguments; returns a new Task.
    def partial(self, **kwargs: Any) -> Task[FParams, FReturn]: ...

    # Map over each keyword argument's values.
    def expand(self, **kwargs: OperatorExpandArgument) -> XComArg: ...

    # Map over a sequence of complete kwargs dicts (or an XComArg yielding one).
    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: ...

    # Override operator arguments; returns a new Task.
    def override(self, **kwargs: Any) -> Task[FParams, FReturn]: ...
662
663
class TaskDecorator(Protocol):
    """Type declaration for ``task_decorator_factory`` return type."""

    @overload
    def __call__(  # type: ignore[misc]
        self,
        python_callable: Callable[FParams, FReturn],
    ) -> Task[FParams, FReturn]:
        """For the "bare decorator" ``@task`` case."""

    @overload
    def __call__(
        self,
        *,
        multiple_outputs: bool | None = None,
        **kwargs: Any,
    ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
        """For the decorator factory ``@task()`` case."""

    # Mirrors _TaskDecorator.override: customize operator kwargs before applying.
    def override(self, **kwargs: Any) -> Task[FParams, FReturn]: ...
684
685
def task_decorator_factory(
    python_callable: Callable | None = None,
    *,
    multiple_outputs: bool | None = None,
    decorated_operator_class: type[BaseOperator],
    **kwargs,
) -> TaskDecorator:
    """
    Generate a wrapper that wraps a function into an Airflow operator.

    Can be reused in a single Dag.

    :param python_callable: Function to decorate.
    :param multiple_outputs: If set to True, the decorated function's return
        value will be unrolled to multiple XCom values. Dict will unroll to XCom
        values with its keys as XCom keys. If set to False (default), only at
        most one XCom value is pushed.
    :param decorated_operator_class: The operator that executes the logic needed
        to run the python function in the correct environment.

    Other kwargs are directly forwarded to the underlying operator class when
    it's instantiated.
    """
    # When the caller leaves the choice open, let attrs infer the value from
    # the function's return annotation (the attr.ib default kicks in).
    resolved_outputs = cast("bool", attr.NOTHING) if multiple_outputs is None else multiple_outputs

    def _wrap(func: Callable) -> _TaskDecorator:
        return _TaskDecorator(
            function=func,
            multiple_outputs=resolved_outputs,
            operator_class=decorated_operator_class,
            kwargs=kwargs,
        )

    if python_callable:
        # Bare ``@task`` usage: wrap the function immediately.
        return cast("TaskDecorator", _wrap(python_callable))
    if python_callable is not None:
        # A non-None but falsy positional argument was supplied; reject it.
        raise TypeError("No args allowed while using @task, use kwargs instead")
    # ``@task(...)`` factory usage: return a decorator to apply later.
    return cast("TaskDecorator", _wrap)