1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import contextlib
21import copy
22import warnings
23from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
24from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeGuard
25
26import attrs
27import methodtools
28from lazy_object_proxy import Proxy
29
30from airflow.sdk.bases.xcom import BaseXCom
31from airflow.sdk.definitions._internal.abstractoperator import (
32 DEFAULT_EXECUTOR,
33 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
34 DEFAULT_OWNER,
35 DEFAULT_POOL_NAME,
36 DEFAULT_POOL_SLOTS,
37 DEFAULT_PRIORITY_WEIGHT,
38 DEFAULT_QUEUE,
39 DEFAULT_RETRIES,
40 DEFAULT_RETRY_DELAY,
41 DEFAULT_TRIGGER_RULE,
42 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
43 DEFAULT_WEIGHT_RULE,
44 AbstractOperator,
45 TaskStateChangeCallbackAttrType,
46)
47from airflow.sdk.definitions._internal.expandinput import (
48 DictOfListsExpandInput,
49 ListOfDictsExpandInput,
50 is_mappable,
51)
52from airflow.sdk.definitions._internal.types import NOTSET
53from airflow.serialization.enums import DagAttributeTypes
54from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
55
56if TYPE_CHECKING:
57 import datetime
58
59 import jinja2 # Slow import.
60 import pendulum
61
62 from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, TriggerRule, XComArg
63 from airflow.sdk.definitions._internal.expandinput import (
64 ExpandInput,
65 OperatorExpandArgument,
66 OperatorExpandKwargsArgument,
67 )
68 from airflow.sdk.definitions.operator_resources import Resources
69 from airflow.sdk.definitions.param import ParamsDict
70 from airflow.triggers.base import StartTriggerArgs
71
# The two APIs whose kwargs we validate; used to tailor error messages in
# ``validate_mapping_kwargs``. A single multi-value Literal is the idiomatic
# (and per PEP 586, equivalent) spelling of a union of Literals.
ValidationSource = Literal["expand", "partial"]
73
74
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """
    Validate kwargs passed to ``partial()`` or ``expand()``.

    Walks *op*'s MRO, checking each name in *value* against the parameter
    names collected by ``BaseOperatorMeta`` on every ``__init__``. For
    ``expand()`` the matched values must additionally be mappable.

    :param op: Operator class being partial-ed or expanded.
    :param func: Which API the kwargs came from ("partial" or "expand").
    :param value: Keyword arguments supplied by the user.
    :raises ValueError: If an ``expand()`` argument is not a mappable type.
    :raises TypeError: If an argument is not accepted by any ``__init__``.
    """
    # Use a dict so order of args is same as code order.
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            # Use a dedicated local instead of reassigning the *value*
            # parameter, which would confusingly shadow the input dict.
            arg_value = unknown_args.pop(name, NOTSET)
            if func != "expand":
                continue
            if arg_value is NOTSET:
                continue
            if is_mappable(arg_value):
                continue
            type_name = type(arg_value).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
104
105
def _is_container(obj: Any) -> bool:
    """Return True if *obj* is iterable but not a plain string."""
    # A Proxy implements __iter__ by forwarding to its lazily initialized
    # target, so unwrap it first and inspect the real object instead.
    target = obj.__wrapped__ if isinstance(obj, Proxy) else obj
    return not isinstance(target, str) and hasattr(target, "__iter__")
114
115
def ensure_xcomarg_return_value(arg: Any) -> None:
    """Recursively reject any XComArg in *arg* that uses a custom XCom key."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        # Only the default return-value key can be mapped over.
        for operator, key in arg.iter_references():
            if key != BaseXCom.XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not _is_container(arg):
        return
    if isinstance(arg, Mapping):
        for nested in arg.values():
            ensure_xcomarg_return_value(nested)
    elif isinstance(arg, Iterable):
        for nested in arg:
            ensure_xcomarg_return_value(nested)
131
132
133def is_mappable_value(value: Any) -> TypeGuard[Collection]:
134 """
135 Whether a value can be used for task mapping.
136
137 We only allow collections with guaranteed ordering, but exclude character
138 sequences since that's usually not what users would expect to be mappable.
139
140 :meta private:
141 """
142 if not isinstance(value, (Sequence, dict)):
143 return False
144 if isinstance(value, (bytearray, bytes, str)):
145 return False
146 return True
147
148
def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """
    Ensure *kwargs1* and *kwargs2* do not contain common keys.

    :raises TypeError: If common keys are found.
    """
    common = {key for key in kwargs1 if key in kwargs2}
    if not common:
        return
    if len(common) == 1:
        raise TypeError(f"{fail_reason} argument: {common.pop()}")
    # List all offending keys in a stable (sorted) order.
    raise TypeError(f"{fail_reason} arguments: {', '.join(sorted(common))}")
162
163
@attrs.define(kw_only=True, repr=False)
class OperatorPartial:
    """
    An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at Dag-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the Dag.
    """

    # The operator class to instantiate when the mapped task is unmapped.
    operator_class: type[BaseOperator]
    # Keyword arguments fixed by partial(); shared by all expanded tasks.
    kwargs: dict[str, Any]
    # Params to pass through to the mapped operator.
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # Warn at garbage collection if expand() was never called: a partial
        # alone never becomes a task in the Dag.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Expand each keyword argument over its mappable value."""
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Expand against a list of whole-kwargs dicts (or an XComArg resolving to one)."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # Report the offending item's type, not the containing
                    # sequence's; otherwise the message reads "expected ...
                    # list[dict], not list", which hides the real problem.
                    raise TypeError(f"expected XComArg or list[dict], not {type(item).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` from this partial and *expand_input*."""
        from airflow.providers.standard.operators.empty import EmptyOperator
        from airflow.providers.standard.utils.skipmixin import SkipMixin
        from airflow.sdk import BaseSensorOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        # These become explicit MappedOperator attributes rather than opaque
        # partial kwargs.
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date", None)
        end_date = partial_kwargs.pop("end_date", None)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            is_sensor=issubclass(self.operator_class, BaseSensorOperator),
            can_skip_downstream=issubclass(self.operator_class, SkipMixin),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            # TODO: Move these to task SDK's BaseOperator and remove getattr
            start_trigger_args=getattr(self.operator_class, "start_trigger_args", None),
            start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)),
        )
        return op
266
267
@attrs.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's Dag serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through Dag pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a Dag."""

    # The unmapped operator class; instantiated at runtime by unmap().
    operator_class: type[BaseOperator]

    _is_mapped: bool = attrs.field(init=False, default=True)

    # Kwargs passed to expand()/expand_kwargs(), still unresolved.
    expand_input: ExpandInput
    # Kwargs fixed by partial(); shared by every expanded task instance.
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool = attrs.field(alias="is_empty")
    _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream")
    _is_sensor: bool = attrs.field(alias="is_sensor", default=False)
    _task_module: str
    task_type: str
    _operator_name: str
    start_trigger_args: StartTriggerArgs | None
    start_from_trigger: bool
    _needs_expansion: bool = True

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    # Relationship sets are populated via dependency wiring, not the initializer.
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    # Attribute names hidden from the UI, extending the base class's set.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
    )
331
    def __hash__(self):
        # Identity-based hash: instances are mutable (attrs fields, relationship
        # sets), so hashing by id keeps them safely usable in sets and dicts.
        return id(self)

    def __repr__(self):
        return f"<Mapped({self.task_type}): {self.task_id}>"
337
    def __attrs_post_init__(self):
        from airflow.sdk.definitions.xcom_arg import XComArg

        # Mapping inside an already-expanded task group is not supported.
        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        # Register this task with its task group and Dag, mirroring what a
        # regular operator does on construction.
        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        # Wire upstream dependencies for XComArgs referenced by the expand
        # input, and by any partial kwargs that are template fields.
        XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)
352
353 @methodtools.lru_cache(maxsize=None)
354 @classmethod
355 def get_serialized_fields(cls):
356 # Not using 'cls' here since we only want to serialize base fields.
357 return (frozenset(attrs.fields_dict(MappedOperator))) - {
358 "_is_empty",
359 "_can_skip_downstream",
360 "dag",
361 "deps",
362 "expand_input", # This is needed to be able to accept XComArg.
363 "task_group",
364 "upstream_task_ids",
365 "_is_setup",
366 "_is_teardown",
367 "_on_failure_fail_dagrun",
368 "operator_class",
369 "_needs_expansion",
370 "partial_kwargs",
371 "operator_extra_links",
372 }
373
    @property
    def operator_name(self) -> str:
        # Display name for the unmapped class (its custom_operator_name when
        # declared, otherwise the class name — see OperatorPartial._expand).
        return self._operator_name

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    # The accessors below expose operator arguments captured by partial();
    # each reads from partial_kwargs, falling back to the SDK-wide default
    # (or None) when the user did not set a value.

    @property
    def task_display_name(self) -> str:
        return self.partial_kwargs.get("task_display_name") or self.task_id

    @property
    def owner(self) -> str:
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @owner.setter
    def owner(self, value: str) -> None:
        self.partial_kwargs["owner"] = value

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def email_on_failure(self) -> bool:
        return self.partial_kwargs.get("email_on_failure", True)

    @property
    def email_on_retry(self) -> bool:
        return self.partial_kwargs.get("email_on_retry", True)

    @property
    def map_index_template(self) -> None | str:
        return self.partial_kwargs.get("map_index_template")

    @map_index_template.setter
    def map_index_template(self, value: str | None) -> None:
        self.partial_kwargs["map_index_template"] = value
419
    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @trigger_rule.setter
    def trigger_rule(self, value):
        self.partial_kwargs["trigger_rule"] = value

    @property
    def is_setup(self) -> bool:
        # bool() coercion so a missing key reads as False.
        return bool(self.partial_kwargs.get("is_setup"))

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self.partial_kwargs["is_setup"] = value

    @property
    def is_teardown(self) -> bool:
        return bool(self.partial_kwargs.get("is_teardown"))

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        self.partial_kwargs["is_teardown"] = value

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @depends_on_past.setter
    def depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["depends_on_past"] = value

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @ignore_first_depends_on_past.setter
    def ignore_first_depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["ignore_first_depends_on_past"] = value

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @wait_for_past_depends_before_skipping.setter
    def wait_for_past_depends_before_skipping(self, value: bool) -> None:
        self.partial_kwargs["wait_for_past_depends_before_skipping"] = value

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @wait_for_downstream.setter
    def wait_for_downstream(self, value: bool) -> None:
        self.partial_kwargs["wait_for_downstream"] = value
479
    # Scheduling and resource knobs captured by partial(); read from
    # partial_kwargs with the SDK-wide default as fallback where one exists.

    @property
    def retries(self) -> int:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @retries.setter
    def retries(self, value: int) -> None:
        self.partial_kwargs["retries"] = value

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @queue.setter
    def queue(self, value: str) -> None:
        self.partial_kwargs["queue"] = value

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME)

    @pool.setter
    def pool(self, value: str) -> None:
        self.partial_kwargs["pool"] = value

    @property
    def pool_slots(self) -> int:
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @pool_slots.setter
    def pool_slots(self, value: int) -> None:
        self.partial_kwargs["pool_slots"] = value

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @execution_timeout.setter
    def execution_timeout(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["execution_timeout"] = value

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @max_retry_delay.setter
    def max_retry_delay(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["max_retry_delay"] = value

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @retry_delay.setter
    def retry_delay(self, value: datetime.timedelta) -> None:
        self.partial_kwargs["retry_delay"] = value
535
536 @property
537 def retry_exponential_backoff(self) -> float:
538 value = self.partial_kwargs.get("retry_exponential_backoff", 0)
539 if value is True:
540 return 2.0
541 if value is False:
542 return 0.0
543 return float(value)
544
545 @retry_exponential_backoff.setter
546 def retry_exponential_backoff(self, value: float) -> None:
547 self.partial_kwargs["retry_exponential_backoff"] = value
548
    @property
    def priority_weight(self) -> int:
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @priority_weight.setter
    def priority_weight(self, value: int) -> None:
        self.partial_kwargs["priority_weight"] = value

    @property
    def weight_rule(self) -> PriorityWeightStrategy:
        # Validate and load on every access so the stored value may be either
        # a strategy name or an instance.
        return validate_and_load_priority_weight_strategy(
            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
        )

    @weight_rule.setter
    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
        self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @max_active_tis_per_dag.setter
    def max_active_tis_per_dag(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dag"] = value

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @max_active_tis_per_dagrun.setter
    def max_active_tis_per_dagrun(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dagrun"] = value
582
    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

    # Callback accessors normalize falsy stored values (e.g. None) to an empty
    # list so callers can always iterate over the result.

    @property
    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_execute_callback") or []

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_execute_callback"] = value or []

    @property
    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_failure_callback") or []

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_failure_callback"] = value or []

    @property
    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_retry_callback") or []

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_retry_callback"] = value or []

    @property
    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_success_callback") or []

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_success_callback"] = value or []

    @property
    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_skipped_callback") or []

    @on_skipped_callback.setter
    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_skipped_callback"] = value or []
626
    # True when at least one callback of the corresponding kind is configured.

    @property
    def has_on_execute_callback(self) -> bool:
        return bool(self.on_execute_callback)

    @property
    def has_on_failure_callback(self) -> bool:
        return bool(self.on_failure_callback)

    @property
    def has_on_retry_callback(self) -> bool:
        return bool(self.on_retry_callback)

    @property
    def has_on_success_callback(self) -> bool:
        return bool(self.on_success_callback)

    @property
    def has_on_skipped_callback(self) -> bool:
        return bool(self.on_skipped_callback)
646
    # Miscellaneous views over partial_kwargs.

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor(self) -> str | None:
        return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property
    def inlets(self) -> list[Any]:
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:
        self.partial_kwargs["inlets"] = value

    @property
    def outlets(self) -> list[Any]:
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:
        self.partial_kwargs["outlets"] = value

    # Per-format documentation attached to the task, if the user provided any.

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    @property
    def allow_nested_operators(self) -> bool:
        return bool(self.partial_kwargs.get("allow_nested_operators"))
698
    def get_dag(self) -> DAG | None:
        """Implement Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return reference to XCom pushed by current operator."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        # Refers to this task's default return-value XCom.
        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implement DAGNode."""
        return DagAttributeTypes.OP, self.task_id
713
    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.

        :param context: Execution context used to resolve the expand input.
        :return: The resolved kwargs, and the ids of objects already resolved
            (passed on as ``seen_oids`` during template rendering).
        """
        return self._get_specified_expand_input().resolve(context)
722
723 def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
724 """
725 Get init kwargs to unmap the underlying operator class.
726
727 :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
728 """
729 if strict:
730 prevent_duplicates(
731 self.partial_kwargs,
732 mapped_kwargs,
733 fail_reason="unmappable or already specified",
734 )
735
736 # If params appears in the mapped kwargs, we need to merge it into the
737 # partial params, overriding existing keys.
738 params = copy.copy(self.params)
739 with contextlib.suppress(KeyError):
740 params.update(mapped_kwargs["params"])
741
742 # Ordering is significant; mapped kwargs should override partial ones,
743 # and the specially handled params should be respected.
744 return {
745 "task_id": self.task_id,
746 "dag": self.dag,
747 "task_group": self.task_group,
748 "start_date": self.start_date,
749 "end_date": self.end_date,
750 **self.partial_kwargs,
751 **mapped_kwargs,
752 "params": params,
753 }
754
    def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator:
        """
        Get the "normal" Operator after applying the current mapping.

        :param resolve: Already-resolved mapped kwargs, or an unpackable value
            forwarded to ``_expand_mapped_kwargs``.
        :raises RuntimeError: If *resolve* is None.
        :meta private:
        """
        if isinstance(resolve, Mapping):
            kwargs = resolve
        elif resolve is not None:
            # NOTE(review): despite the annotation, a non-Mapping, non-None
            # *resolve* is star-unpacked into _expand_mapped_kwargs here —
            # presumably a legacy tuple-shaped argument; confirm against callers.
            kwargs, _ = self._expand_mapped_kwargs(*resolve)
        else:
            raise RuntimeError("cannot unmap a non-serialized operator without context")
        kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
        # Popped so they are not passed to __init__; they are assigned on the
        # constructed instance afterwards instead.
        is_setup = kwargs.pop("is_setup", False)
        is_teardown = kwargs.pop("is_teardown", False)
        on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
        kwargs["task_id"] = self.task_id
        op = self.operator_class(**kwargs, _airflow_from_mapped=True)
        op.is_setup = is_setup
        op.is_teardown = is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        # The unmapped operator inherits this mapped task's relationships.
        op.downstream_task_ids = self.downstream_task_ids
        op.upstream_task_ids = self.upstream_task_ids
        return op
779
    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        # _expand_input_attr names the attribute to read (see its doc above).
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    # TODO (GH-52141): Do we need this in the SDK?
    def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator
797
    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """
        Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        from airflow.sdk.execution_time.context import context_update_for_unmapped

        if not jinja_env:
            jinja_env = self.get_template_env()

        # Resolve the expand input against the context, then build the actual
        # operator instance from the resolved kwargs.
        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            # Objects already resolved during expansion are skipped here.
            seen_oids=seen_oids,
        )
833
834 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
835 """
836 Get the kwargs to create the unmapped start_trigger_args.
837
838 This method is for allowing mapped operator to start execution from triggerer.
839 """
840 from airflow.triggers.base import StartTriggerArgs
841
842 if not self.start_trigger_args:
843 return None
844
845 mapped_kwargs, _ = self._expand_mapped_kwargs(context)
846 if self._disallow_kwargs_override:
847 prevent_duplicates(
848 self.partial_kwargs,
849 mapped_kwargs,
850 fail_reason="unmappable or already specified",
851 )
852
853 # Ordering is significant; mapped kwargs should override partial ones.
854 trigger_kwargs = mapped_kwargs.get(
855 "trigger_kwargs",
856 self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
857 )
858 next_kwargs = mapped_kwargs.get(
859 "next_kwargs",
860 self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs),
861 )
862 timeout = mapped_kwargs.get(
863 "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout)
864 )
865 return StartTriggerArgs(
866 trigger_cls=self.start_trigger_args.trigger_cls,
867 trigger_kwargs=trigger_kwargs,
868 next_method=self.start_trigger_args.next_method,
869 next_kwargs=next_kwargs,
870 timeout=timeout,
871 )