1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import contextlib
21import copy
22import warnings
23from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
24from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeGuard
25
26import attrs
27import methodtools
28from lazy_object_proxy import Proxy
29
30from airflow.sdk.bases.xcom import BaseXCom
31from airflow.sdk.definitions._internal.abstractoperator import (
32 DEFAULT_EXECUTOR,
33 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
34 DEFAULT_OWNER,
35 DEFAULT_POOL_NAME,
36 DEFAULT_POOL_SLOTS,
37 DEFAULT_PRIORITY_WEIGHT,
38 DEFAULT_QUEUE,
39 DEFAULT_RETRIES,
40 DEFAULT_RETRY_DELAY,
41 DEFAULT_TRIGGER_RULE,
42 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
43 DEFAULT_WEIGHT_RULE,
44 AbstractOperator,
45 TaskStateChangeCallbackAttrType,
46)
47from airflow.sdk.definitions._internal.expandinput import (
48 DictOfListsExpandInput,
49 ListOfDictsExpandInput,
50 is_mappable,
51)
52from airflow.sdk.definitions._internal.types import NOTSET
53from airflow.serialization.enums import DagAttributeTypes
54
55if TYPE_CHECKING:
56 import datetime
57
58 import jinja2 # Slow import.
59 import pendulum
60
61 from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, TriggerRule, XComArg
62 from airflow.sdk.definitions._internal.expandinput import (
63 ExpandInput,
64 OperatorExpandArgument,
65 OperatorExpandKwargsArgument,
66 )
67 from airflow.sdk.definitions.operator_resources import Resources
68 from airflow.sdk.definitions.param import ParamsDict
69 from airflow.task.priority_strategy import PriorityWeightStrategy
70 from airflow.triggers.base import StartTriggerArgs
71
# Which map-declaration entry point is being validated; only used to tailor
# error messages. PEP 586: a union of Literals is equivalent to one Literal
# listing all values, so spell it in the single-Literal form.
ValidationSource = Literal["expand", "partial"]
73
74
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """
    Validate kwargs passed to ``partial()`` or ``expand()`` on *op*.

    :param op: Operator class the keyword arguments are destined for.
    :param func: Which entry point is being validated (``"partial"`` or ``"expand"``).
    :param value: The keyword arguments supplied by the user.
    :raises ValueError: If an ``expand()`` argument is not a mappable value.
    :raises TypeError: If an argument is not accepted by *op* at all.
    """
    if not value:
        # Nothing to validate. Without this guard, an empty dict on a class
        # with no collected param names fell through to a garbled
        # "got unexpected keyword arguments " TypeError below.
        return
    # use a dict so order of args is same as code order
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            # Pop so anything left at the end is truly unknown. Use a fresh
            # name here; the original code shadowed (and clobbered) the
            # *value* parameter.
            arg_value = unknown_args.pop(name, NOTSET)
            if func != "expand":
                continue
            if arg_value is NOTSET:
                continue
            if is_mappable(arg_value):
                continue
            type_name = type(arg_value).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
104
105
def _is_container(obj: Any) -> bool:
    """Test if an object is a container (iterable) but not a string."""
    if isinstance(obj, Proxy):
        # A Proxy always exposes __iter__ (it forwards the call to the lazily
        # initialized target), so unwrap it first and inspect the real object.
        obj = obj.__wrapped__
    if isinstance(obj, str):
        return False
    return hasattr(obj, "__iter__")
114
115
def ensure_xcomarg_return_value(arg: Any) -> None:
    """Recursively reject any XComArg in *arg* that uses a non-default key.

    Task mapping can only resolve the default ("return value") XCom of an
    upstream task, so custom-keyed references are raised as errors.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        for operator, key in arg.iter_references():
            if key != BaseXCom.XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not _is_container(arg):
        return
    if isinstance(arg, Mapping):
        for item in arg.values():
            ensure_xcomarg_return_value(item)
    elif isinstance(arg, Iterable):
        for item in arg:
            ensure_xcomarg_return_value(item)
131
132
133def is_mappable_value(value: Any) -> TypeGuard[Collection]:
134 """
135 Whether a value can be used for task mapping.
136
137 We only allow collections with guaranteed ordering, but exclude character
138 sequences since that's usually not what users would expect to be mappable.
139
140 :meta private:
141 """
142 if not isinstance(value, (Sequence, dict)):
143 return False
144 if isinstance(value, (bytearray, bytes, str)):
145 return False
146 return True
147
148
def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """
    Ensure *kwargs1* and *kwargs2* do not contain common keys.

    :raises TypeError: If common keys are found.
    """
    common = set(kwargs1) & set(kwargs2)
    if not common:
        return
    if len(common) == 1:
        raise TypeError(f"{fail_reason} argument: {common.pop()}")
    # Sort for a deterministic message when several keys collide.
    raise TypeError(f"{fail_reason} arguments: {', '.join(sorted(common))}")
162
163
@attrs.define(kw_only=True, repr=False)
class OperatorPartial:
    """
    An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at Dag-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the Dag.
    """

    # The operator class to be mapped, plus the kwargs and params captured by
    # the partial() call.
    operator_class: type[BaseOperator]
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        # Reject partial() kwargs the operator class does not accept.
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # Warn at garbage collection when the user never called expand(): a
        # dangling partial() silently adds no task to the Dag.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Create a ``MappedOperator`` that maps over each keyword's values."""
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Create a ``MappedOperator`` that maps over whole kwargs dicts.

        Accepts either a sequence of dicts/XComArgs, or a single XComArg that
        resolves to such a sequence at runtime.
        """
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` from this partial and *expand_input*."""
        from airflow.providers.standard.operators.empty import EmptyOperator
        from airflow.providers.standard.utils.skipmixin import SkipMixin
        from airflow.sdk import BaseSensorOperator

        self._expand_called = True
        # Mapping only works over default ("return value") XComs.
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        # These become first-class MappedOperator attributes rather than
        # staying in partial_kwargs.
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date", None)
        end_date = partial_kwargs.pop("end_date", None)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            is_sensor=issubclass(self.operator_class, BaseSensorOperator),
            can_skip_downstream=issubclass(self.operator_class, SkipMixin),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            # TODO: Move these to task SDK's BaseOperator and remove getattr
            start_trigger_args=getattr(self.operator_class, "start_trigger_args", None),
            start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)),
        )
        return op
266
267
@attrs.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's Dag serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through Dag pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a Dag."""

    # The concrete operator class that unmap() instantiates at runtime.
    operator_class: type[BaseOperator]

    # Always True; distinguishes mapped from unmapped operators at runtime.
    _is_mapped: bool = attrs.field(init=False, default=True)

    # The user-supplied expand()/expand_kwargs() input, and the literal kwargs
    # from partial() minus task_id/dag/task_group/start_date/end_date.
    expand_input: ExpandInput
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool = attrs.field(alias="is_empty")
    _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream")
    _is_sensor: bool = attrs.field(alias="is_sensor", default=False)
    _task_module: str
    task_type: str
    _operator_name: str
    start_trigger_args: StartTriggerArgs | None
    start_from_trigger: bool
    _needs_expansion: bool = True

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    # Attributes hidden from the UI, on top of those AbstractOperator hides.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
    )
331
    def __hash__(self):
        # Identity hash: instances are mutable, so equality/hash by content
        # would be unsafe.
        return id(self)

    def __repr__(self):
        return f"<Mapped({self.task_type}): {self.task_id}>"

    def __attrs_post_init__(self):
        from airflow.sdk.definitions.xcom_arg import XComArg

        # Expanding inside an already-expanded task group is unsupported.
        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        # Register this task with its group and Dag, then wire upstream
        # dependencies for every XComArg in the expand input and in any
        # templated partial kwargs.
        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)
352
    @methodtools.lru_cache(maxsize=None)
    @classmethod
    def get_serialized_fields(cls):
        # Not using 'cls' here since we only want to serialize base fields.
        # Excluded names are either runtime-only links (dag, task_group),
        # reconstructed separately during deserialization (expand_input,
        # operator_extra_links, partial_kwargs), or derived flags.
        return (frozenset(attrs.fields_dict(MappedOperator))) - {
            "_is_empty",
            "_can_skip_downstream",
            "dag",
            "deps",
            "expand_input",  # This is needed to be able to accept XComArg.
            "task_group",
            "upstream_task_ids",
            "_is_setup",
            "_is_teardown",
            "_on_failure_fail_dagrun",
            "operator_class",
            "_needs_expansion",
            "partial_kwargs",
            "operator_extra_links",
        }
373
    @property
    def operator_name(self) -> str:
        return self._operator_name

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    # The properties below (and most that follow) proxy BaseOperator-style
    # attributes that were given to partial(): getters read, and setters
    # write, ``partial_kwargs`` so the values take effect when the operator
    # is eventually unmapped.

    @property
    def task_display_name(self) -> str:
        return self.partial_kwargs.get("task_display_name") or self.task_id

    @property
    def owner(self) -> str:
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @owner.setter
    def owner(self, value: str) -> None:
        self.partial_kwargs["owner"] = value

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def email_on_failure(self) -> bool:
        return self.partial_kwargs.get("email_on_failure", True)

    @property
    def email_on_retry(self) -> bool:
        return self.partial_kwargs.get("email_on_retry", True)

    @property
    def map_index_template(self) -> None | str:
        return self.partial_kwargs.get("map_index_template")

    @map_index_template.setter
    def map_index_template(self, value: str | None) -> None:
        self.partial_kwargs["map_index_template"] = value
419
    # Scheduling-related attributes proxied through partial_kwargs.

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @trigger_rule.setter
    def trigger_rule(self, value):
        self.partial_kwargs["trigger_rule"] = value

    @property
    def is_setup(self) -> bool:
        # bool() normalizes a missing key (None) to False.
        return bool(self.partial_kwargs.get("is_setup"))

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self.partial_kwargs["is_setup"] = value

    @property
    def is_teardown(self) -> bool:
        return bool(self.partial_kwargs.get("is_teardown"))

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        self.partial_kwargs["is_teardown"] = value

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @depends_on_past.setter
    def depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["depends_on_past"] = value

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @ignore_first_depends_on_past.setter
    def ignore_first_depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["ignore_first_depends_on_past"] = value

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @wait_for_past_depends_before_skipping.setter
    def wait_for_past_depends_before_skipping(self, value: bool) -> None:
        self.partial_kwargs["wait_for_past_depends_before_skipping"] = value

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @wait_for_downstream.setter
    def wait_for_downstream(self, value: bool) -> None:
        self.partial_kwargs["wait_for_downstream"] = value
479
    # Retry, queue and pool attributes proxied through partial_kwargs.

    @property
    def retries(self) -> int:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @retries.setter
    def retries(self, value: int) -> None:
        self.partial_kwargs["retries"] = value

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @queue.setter
    def queue(self, value: str) -> None:
        self.partial_kwargs["queue"] = value

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME)

    @pool.setter
    def pool(self, value: str) -> None:
        self.partial_kwargs["pool"] = value

    @property
    def pool_slots(self) -> int:
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @pool_slots.setter
    def pool_slots(self, value: int) -> None:
        self.partial_kwargs["pool_slots"] = value

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @execution_timeout.setter
    def execution_timeout(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["execution_timeout"] = value

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @max_retry_delay.setter
    def max_retry_delay(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["max_retry_delay"] = value

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @retry_delay.setter
    def retry_delay(self, value: datetime.timedelta) -> None:
        self.partial_kwargs["retry_delay"] = value

    @property
    def retry_exponential_backoff(self) -> float:
        # Normalize legacy boolean values to the float multiplier form:
        # True becomes 2.0 and False becomes 0.0 (presumably 2.0 is the
        # historical default backoff base -- confirm against BaseOperator).
        value = self.partial_kwargs.get("retry_exponential_backoff", 0)
        if value is True:
            return 2.0
        if value is False:
            return 0.0
        return float(value)

    @retry_exponential_backoff.setter
    def retry_exponential_backoff(self, value: float) -> None:
        self.partial_kwargs["retry_exponential_backoff"] = value
548
    # Priority-weight, concurrency and callback attributes proxied through
    # partial_kwargs. Callback getters/setters normalize falsy values (None)
    # to an empty list so callers can always iterate the result.

    @property
    def priority_weight(self) -> int:
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @priority_weight.setter
    def priority_weight(self, value: int) -> None:
        self.partial_kwargs["priority_weight"] = value

    @property
    def weight_rule(self) -> PriorityWeightStrategy:
        return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)

    @weight_rule.setter
    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
        self.partial_kwargs["weight_rule"] = value

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @max_active_tis_per_dag.setter
    def max_active_tis_per_dag(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dag"] = value

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @max_active_tis_per_dagrun.setter
    def max_active_tis_per_dagrun(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dagrun"] = value

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

    @property
    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_execute_callback") or []

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_execute_callback"] = value or []

    @property
    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_failure_callback") or []

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_failure_callback"] = value or []

    @property
    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_retry_callback") or []

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_retry_callback"] = value or []

    @property
    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_success_callback") or []

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_success_callback"] = value or []

    @property
    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_skipped_callback") or []

    @on_skipped_callback.setter
    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_skipped_callback"] = value or []
624
    # Convenience flags reporting whether each callback list is non-empty.

    @property
    def has_on_execute_callback(self) -> bool:
        return bool(self.on_execute_callback)

    @property
    def has_on_failure_callback(self) -> bool:
        return bool(self.on_failure_callback)

    @property
    def has_on_retry_callback(self) -> bool:
        return bool(self.on_retry_callback)

    @property
    def has_on_success_callback(self) -> bool:
        return bool(self.on_success_callback)

    @property
    def has_on_skipped_callback(self) -> bool:
        return bool(self.on_skipped_callback)

    # Remaining execution and documentation attributes proxied through
    # partial_kwargs.

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor(self) -> str | None:
        return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property
    def inlets(self) -> list[Any]:
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:
        self.partial_kwargs["inlets"] = value

    @property
    def outlets(self) -> list[Any]:
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:
        self.partial_kwargs["outlets"] = value

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    @property
    def allow_nested_operators(self) -> bool:
        return bool(self.partial_kwargs.get("allow_nested_operators"))
696
    def get_dag(self) -> DAG | None:
        """Implement Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return reference to XCom pushed by current operator."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implement DAGNode."""
        return DagAttributeTypes.OP, self.task_id

    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.

        :return: The resolved kwargs, plus a set of object ids that is later
            fed to template rendering as *seen_oids*.
        """
        return self._get_specified_expand_input().resolve(context)
720
    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """
        Get init kwargs to unmap the underlying operator class.

        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
        :param strict: If True, raise when a mapped kwarg collides with a
            partial one instead of silently letting the mapped value win.
        """
        if strict:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # If params appears in the mapped kwargs, we need to merge it into the
        # partial params, overriding existing keys.
        params = copy.copy(self.params)
        with contextlib.suppress(KeyError):
            params.update(mapped_kwargs["params"])

        # Ordering is significant; mapped kwargs should override partial ones,
        # and the specially handled params should be respected.
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **mapped_kwargs,
            "params": params,
        }
752
753 def unmap(self, resolve: Mapping[str, Any]) -> BaseOperator:
754 """
755 Get the "normal" Operator after applying the current mapping.
756
757 :meta private:
758 """
759 kwargs = self._get_unmap_kwargs(resolve, strict=self._disallow_kwargs_override)
760 is_setup = kwargs.pop("is_setup", False)
761 is_teardown = kwargs.pop("is_teardown", False)
762 on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
763 kwargs["task_id"] = self.task_id
764 op = self.operator_class(**kwargs, _airflow_from_mapped=True)
765 op.is_setup = is_setup
766 op.is_teardown = is_teardown
767 op.on_failure_fail_dagrun = on_failure_fail_dagrun
768 op.downstream_task_ids = self.downstream_task_ids
769 op.upstream_task_ids = self.upstream_task_ids
770 return op
771
    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self
781
782 # TODO (GH-52141): Do we need this in the SDK?
783 def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]:
784 """Upstream dependencies that provide XComs used by this task for task mapping."""
785 from airflow.sdk.definitions.xcom_arg import XComArg
786
787 for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
788 yield operator
789
    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """
        Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        from airflow.sdk.execution_time.context import context_update_for_unmapped

        if not jinja_env:
            jinja_env = self.get_template_env()

        # Resolve this map index's kwargs, build the unmapped operator from
        # them, and point the context at it; rendering then mutates that
        # unmapped task, never this MappedOperator.
        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
        )
825
    def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
        """
        Get the kwargs to create the unmapped start_trigger_args.

        This method is for allowing mapped operator to start execution from triggerer.

        :param context: Context used to resolve the mapped kwargs.
        :return: The resolved trigger args, or None if the operator class
            defines no ``start_trigger_args``.
        """
        from airflow.triggers.base import StartTriggerArgs

        if not self.start_trigger_args:
            return None

        mapped_kwargs, _ = self._expand_mapped_kwargs(context)
        if self._disallow_kwargs_override:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # Ordering is significant; mapped kwargs should override partial ones.
        # Each value falls back: mapped -> partial -> class-level default.
        trigger_kwargs = mapped_kwargs.get(
            "trigger_kwargs",
            self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
        )
        next_kwargs = mapped_kwargs.get(
            "next_kwargs",
            self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs),
        )
        timeout = mapped_kwargs.get(
            "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout)
        )
        return StartTriggerArgs(
            trigger_cls=self.start_trigger_args.trigger_cls,
            trigger_kwargs=trigger_kwargs,
            next_method=self.start_trigger_args.next_method,
            next_kwargs=next_kwargs,
            timeout=timeout,
        )