1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import contextlib
21import copy
22import warnings
23from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
24from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeGuard
25
26import attrs
27import methodtools
28from lazy_object_proxy import Proxy
29
30from airflow.sdk.bases.xcom import BaseXCom
31from airflow.sdk.definitions._internal.abstractoperator import (
32 DEFAULT_EXECUTOR,
33 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
34 DEFAULT_OWNER,
35 DEFAULT_POOL_NAME,
36 DEFAULT_POOL_SLOTS,
37 DEFAULT_PRIORITY_WEIGHT,
38 DEFAULT_QUEUE,
39 DEFAULT_RETRIES,
40 DEFAULT_RETRY_DELAY,
41 DEFAULT_TRIGGER_RULE,
42 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
43 DEFAULT_WEIGHT_RULE,
44 AbstractOperator,
45 NotMapped,
46 TaskStateChangeCallbackAttrType,
47)
48from airflow.sdk.definitions._internal.expandinput import (
49 DictOfListsExpandInput,
50 ListOfDictsExpandInput,
51 is_mappable,
52)
53from airflow.sdk.definitions._internal.types import NOTSET
54from airflow.serialization.enums import DagAttributeTypes
55from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
56
57if TYPE_CHECKING:
58 import datetime
59
60 import jinja2 # Slow import.
61 import pendulum
62
63 from airflow.models.expandinput import (
64 OperatorExpandArgument,
65 OperatorExpandKwargsArgument,
66 )
67 from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, TriggerRule, XComArg
68 from airflow.sdk.definitions._internal.expandinput import ExpandInput
69 from airflow.sdk.definitions.operator_resources import Resources
70 from airflow.sdk.definitions.param import ParamsDict
71 from airflow.triggers.base import StartTriggerArgs
72
# Which user-facing call is being validated. Written in the canonical
# PEP 586 form: ``Literal["expand", "partial"]`` is equivalent to (and the
# preferred spelling of) a union of single-value Literals.
ValidationSource = Literal["expand", "partial"]
74
75
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """
    Validate arguments passed to ``partial()`` or ``expand()``.

    Walks *op*'s MRO collecting the parameter names recorded by
    ``BaseOperatorMeta`` on each ``__init__``, and checks every key in *value*
    against them. For ``expand``, each recognized argument must also be a
    mappable value.

    :param op: The operator class being partial-ed or expanded.
    :param func: Which call is being validated ("partial" or "expand").
    :param value: The keyword arguments supplied by the user.
    :raises ValueError: If an ``expand()`` argument is not mappable.
    :raises TypeError: If any argument is not a known operator parameter.
    """
    # Copy into a dict so the reported order matches the user's call order.
    remaining = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            candidate = remaining.pop(name, NOTSET)
            if func != "expand" or candidate is NOTSET or is_mappable(candidate):
                continue
            type_name = type(candidate).__name__
            raise ValueError(
                f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            )
        if not remaining:
            return  # Everything accounted for; stop walking the MRO chain.

    if len(remaining) == 1:
        error = f"an unexpected keyword argument {remaining.popitem()[0]!r}"
    else:
        joined = ", ".join(repr(n) for n in remaining)
        error = f"unexpected keyword arguments {joined}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
105
106
def _is_container(obj: Any) -> bool:
    """Return True if *obj* is iterable but not a plain string."""
    # A Proxy always implements __iter__ (it forwards the call to the lazily
    # initialized target), so unwrap it first and inspect the real object.
    target = obj.__wrapped__ if isinstance(obj, Proxy) else obj
    return hasattr(target, "__iter__") and not isinstance(target, str)
115
116
def ensure_xcomarg_return_value(arg: Any) -> None:
    """
    Recursively check that any XComArg in *arg* uses the default return key.

    Mapping over an XCom with a custom key is not supported; raise ValueError
    when one is found anywhere inside nested containers.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        for referenced_operator, xcom_key in arg.iter_references():
            if xcom_key != BaseXCom.XCOM_RETURN_KEY:
                raise ValueError(
                    f"cannot map over XCom with custom key {xcom_key!r} from {referenced_operator}"
                )
        return
    if not _is_container(arg):
        return
    if isinstance(arg, Mapping):
        for nested in arg.values():
            ensure_xcomarg_return_value(nested)
    elif isinstance(arg, Iterable):
        for nested in arg:
            ensure_xcomarg_return_value(nested)
132
133
134def is_mappable_value(value: Any) -> TypeGuard[Collection]:
135 """
136 Whether a value can be used for task mapping.
137
138 We only allow collections with guaranteed ordering, but exclude character
139 sequences since that's usually not what users would expect to be mappable.
140
141 :meta private:
142 """
143 if not isinstance(value, (Sequence, dict)):
144 return False
145 if isinstance(value, (bytearray, bytes, str)):
146 return False
147 return True
148
149
def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """
    Ensure *kwargs1* and *kwargs2* do not contain common keys.

    :raises TypeError: If common keys are found.
    """
    common = set(kwargs1) & set(kwargs2)
    if not common:
        return
    if len(common) == 1:
        raise TypeError(f"{fail_reason} argument: {common.pop()}")
    raise TypeError(f"{fail_reason} arguments: {', '.join(sorted(common))}")
163
164
@attrs.define(kw_only=True, repr=False)
class OperatorPartial:
    """
    An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at Dag-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the Dag.
    """

    # The operator class that will eventually be instantiated when unmapping.
    operator_class: type[BaseOperator]
    # Arguments passed to partial(); shared by every mapped task instance.
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # Warn when a partial is garbage-collected without ever being expanded;
        # the task would otherwise silently not exist in the Dag.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """
        Create a ``MappedOperator`` mapping over per-argument collections.

        :param mapped_kwargs: Keyword arguments, each a mappable collection or
            XComArg, to expand against.
        :raises TypeError: If no arguments are given.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """
        Create a ``MappedOperator`` mapping over a list of whole-kwargs dicts.

        :param kwargs: An XComArg, or a sequence whose items are each an
            XComArg or a mapping of keyword arguments for one map index.
        :param strict: If True, duplicate keys between the expanded kwargs and
            ``partial()`` kwargs fail at execution time.
        :raises TypeError: If *kwargs* (or any item in it) has the wrong type.
        """
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # Report the offending item's type, not the container's.
                    raise TypeError(f"expected XComArg or list[dict], not {type(item).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` from this partial and *expand_input*."""
        from airflow.providers.standard.operators.empty import EmptyOperator
        from airflow.providers.standard.utils.skipmixin import SkipMixin
        from airflow.sdk import BaseSensorOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        # Pull structural arguments out; they are passed to MappedOperator
        # directly rather than kept in partial_kwargs.
        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date", None)
        end_date = partial_kwargs.pop("end_date", None)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            is_sensor=issubclass(self.operator_class, BaseSensorOperator),
            can_skip_downstream=issubclass(self.operator_class, SkipMixin),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            # TODO: Move these to task SDK's BaseOperator and remove getattr
            start_trigger_args=getattr(self.operator_class, "start_trigger_args", None),
            start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)),
        )
        return op
267
268
@attrs.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's Dag serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through Dag pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a Dag."""

    # The operator class instantiated by unmap() when the task actually runs.
    operator_class: type[BaseOperator]

    _is_mapped: bool = attrs.field(init=False, default=True)

    # Arguments given to expand()/expand_kwargs(); resolved per map index.
    expand_input: ExpandInput
    # Arguments given to partial(); shared by every map index. Most of the
    # properties below simply read from / write into this dict.
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool = attrs.field(alias="is_empty")
    _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream")
    _is_sensor: bool = attrs.field(alias="is_sensor", default=False)
    _task_module: str
    task_type: str
    _operator_name: str
    start_trigger_args: StartTriggerArgs | None
    start_from_trigger: bool
    _needs_expansion: bool = True

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
    )

    def __hash__(self) -> int:
        """Hash by object identity; instances are mutable."""
        return id(self)

    def __repr__(self) -> str:
        return f"<Mapped({self.task_type}): {self.task_id}>"

    def __attrs_post_init__(self) -> None:
        """Register this task with its task group/Dag and wire XCom dependencies."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        # Nested mapping (a mapped operator inside an expanded task group) is
        # explicitly unsupported for now.
        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        # Upstream relationships come both from the expand input and from any
        # XComArg used in a templated partial argument.
        XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)

    @methodtools.lru_cache(maxsize=None)
    @classmethod
    def get_serialized_fields(cls):
        """Return the set of attrs field names included in serialization."""
        # Not using 'cls' here since we only want to serialize base fields.
        return (frozenset(attrs.fields_dict(MappedOperator))) - {
            "_is_empty",
            "_can_skip_downstream",
            "dag",
            "deps",
            "expand_input",  # This is needed to be able to accept XComArg.
            "task_group",
            "upstream_task_ids",
            "_is_setup",
            "_is_teardown",
            "_on_failure_fail_dagrun",
            "operator_class",
            "_needs_expansion",
            "partial_kwargs",
            "operator_extra_links",
        }

    @property
    def operator_name(self) -> str:
        return self._operator_name

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    # The properties below mirror BaseOperator's constructor arguments; since a
    # mapped operator stores them in ``partial_kwargs`` instead of instance
    # attributes, each property reads from (and the setters write into) that
    # dict, falling back to the shared DEFAULT_* constants where one exists.

    @property
    def task_display_name(self) -> str:
        return self.partial_kwargs.get("task_display_name") or self.task_id

    @property
    def owner(self) -> str:
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @owner.setter
    def owner(self, value: str) -> None:
        self.partial_kwargs["owner"] = value

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def email_on_failure(self) -> bool:
        return self.partial_kwargs.get("email_on_failure", True)

    @property
    def email_on_retry(self) -> bool:
        return self.partial_kwargs.get("email_on_retry", True)

    @property
    def map_index_template(self) -> None | str:
        return self.partial_kwargs.get("map_index_template")

    @map_index_template.setter
    def map_index_template(self, value: str | None) -> None:
        self.partial_kwargs["map_index_template"] = value

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @trigger_rule.setter
    def trigger_rule(self, value):
        self.partial_kwargs["trigger_rule"] = value

    @property
    def is_setup(self) -> bool:
        return bool(self.partial_kwargs.get("is_setup"))

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self.partial_kwargs["is_setup"] = value

    @property
    def is_teardown(self) -> bool:
        return bool(self.partial_kwargs.get("is_teardown"))

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        self.partial_kwargs["is_teardown"] = value

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @depends_on_past.setter
    def depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["depends_on_past"] = value

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @ignore_first_depends_on_past.setter
    def ignore_first_depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["ignore_first_depends_on_past"] = value

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @wait_for_past_depends_before_skipping.setter
    def wait_for_past_depends_before_skipping(self, value: bool) -> None:
        self.partial_kwargs["wait_for_past_depends_before_skipping"] = value

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @wait_for_downstream.setter
    def wait_for_downstream(self, value: bool) -> None:
        self.partial_kwargs["wait_for_downstream"] = value

    @property
    def retries(self) -> int:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @retries.setter
    def retries(self, value: int) -> None:
        self.partial_kwargs["retries"] = value

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @queue.setter
    def queue(self, value: str) -> None:
        self.partial_kwargs["queue"] = value

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME)

    @pool.setter
    def pool(self, value: str) -> None:
        self.partial_kwargs["pool"] = value

    @property
    def pool_slots(self) -> int:
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @pool_slots.setter
    def pool_slots(self, value: int) -> None:
        self.partial_kwargs["pool_slots"] = value

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @execution_timeout.setter
    def execution_timeout(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["execution_timeout"] = value

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @max_retry_delay.setter
    def max_retry_delay(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["max_retry_delay"] = value

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @retry_delay.setter
    def retry_delay(self, value: datetime.timedelta) -> None:
        self.partial_kwargs["retry_delay"] = value

    @property
    def retry_exponential_backoff(self) -> float:
        # Normalize the stored value: the argument historically accepted a
        # bool, where True means a backoff base of 2 and False means disabled.
        value = self.partial_kwargs.get("retry_exponential_backoff", 0)
        if value is True:
            return 2.0
        if value is False:
            return 0.0
        return float(value)

    @retry_exponential_backoff.setter
    def retry_exponential_backoff(self, value: float) -> None:
        self.partial_kwargs["retry_exponential_backoff"] = value

    @property
    def priority_weight(self) -> int:
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @priority_weight.setter
    def priority_weight(self, value: int) -> None:
        self.partial_kwargs["priority_weight"] = value

    @property
    def weight_rule(self) -> PriorityWeightStrategy:
        return validate_and_load_priority_weight_strategy(
            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
        )

    @weight_rule.setter
    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
        self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @max_active_tis_per_dag.setter
    def max_active_tis_per_dag(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dag"] = value

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @max_active_tis_per_dagrun.setter
    def max_active_tis_per_dagrun(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dagrun"] = value

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

    # Callback properties normalize "no callback" to an empty list so callers
    # can always iterate the value.

    @property
    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_execute_callback") or []

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_execute_callback"] = value or []

    @property
    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_failure_callback") or []

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_failure_callback"] = value or []

    @property
    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_retry_callback") or []

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_retry_callback"] = value or []

    @property
    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_success_callback") or []

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_success_callback"] = value or []

    @property
    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_skipped_callback") or []

    @on_skipped_callback.setter
    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_skipped_callback"] = value or []

    @property
    def has_on_execute_callback(self) -> bool:
        return bool(self.on_execute_callback)

    @property
    def has_on_failure_callback(self) -> bool:
        return bool(self.on_failure_callback)

    @property
    def has_on_retry_callback(self) -> bool:
        return bool(self.on_retry_callback)

    @property
    def has_on_success_callback(self) -> bool:
        return bool(self.on_success_callback)

    @property
    def has_on_skipped_callback(self) -> bool:
        return bool(self.on_skipped_callback)

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor(self) -> str | None:
        return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property
    def inlets(self) -> list[Any]:
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:
        self.partial_kwargs["inlets"] = value

    @property
    def outlets(self) -> list[Any]:
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:
        self.partial_kwargs["outlets"] = value

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    @property
    def allow_nested_operators(self) -> bool:
        return bool(self.partial_kwargs.get("allow_nested_operators"))

    def get_dag(self) -> DAG | None:
        """Implement Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return reference to XCom pushed by current operator."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implement DAGNode."""
        return DagAttributeTypes.OP, self.task_id

    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.
        """
        return self._get_specified_expand_input().resolve(context)

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """
        Get init kwargs to unmap the underlying operator class.

        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
        """
        if strict:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # If params appears in the mapped kwargs, we need to merge it into the
        # partial params, overriding existing keys.
        params = copy.copy(self.params)
        with contextlib.suppress(KeyError):
            params.update(mapped_kwargs["params"])

        # Ordering is significant; mapped kwargs should override partial ones,
        # and the specially handled params should be respected.
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **mapped_kwargs,
            "params": params,
        }

    def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator:
        """
        Get the "normal" Operator after applying the current mapping.

        :meta private:
        """
        if isinstance(resolve, Mapping):
            kwargs = resolve
        elif resolve is not None:
            # NOTE(review): *resolve* is star-unpacked here even though the
            # annotation only allows a Mapping or None — presumably this
            # supports a legacy tuple form from older callers; confirm before
            # tightening the signature.
            kwargs, _ = self._expand_mapped_kwargs(*resolve)
        else:
            raise RuntimeError("cannot unmap a non-serialized operator without context")
        kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
        # These flags are properties on BaseOperator (not __init__ args), so
        # pop them out and assign after construction.
        is_setup = kwargs.pop("is_setup", False)
        is_teardown = kwargs.pop("is_teardown", False)
        on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
        kwargs["task_id"] = self.task_id
        op = self.operator_class(**kwargs, _airflow_from_mapped=True)
        op.is_setup = is_setup
        op.is_teardown = is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        op.downstream_task_ids = self.downstream_task_ids
        op.upstream_task_ids = self.upstream_task_ids
        return op

    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    # TODO (GH-52141): Do we need this in the SDK?
    def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """Return the number of mapped task instances derivable at parse time."""
        current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
        try:
            # The use of `methodtools` interferes with the zero-arg super
            parent_count = super(MappedOperator, self).get_parse_time_mapped_ti_count()  # noqa: UP008
        except NotMapped:
            return current_count
        return parent_count * current_count

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """
        Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        from airflow.sdk.execution_time.context import context_update_for_unmapped

        if not jinja_env:
            jinja_env = self.get_template_env()

        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
        )

    def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
        """
        Get the kwargs to create the unmapped start_trigger_args.

        This method is for allowing mapped operator to start execution from triggerer.
        """
        from airflow.triggers.base import StartTriggerArgs

        if not self.start_trigger_args:
            return None

        mapped_kwargs, _ = self._expand_mapped_kwargs(context)
        if self._disallow_kwargs_override:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # Ordering is significant; mapped kwargs should override partial ones.
        trigger_kwargs = mapped_kwargs.get(
            "trigger_kwargs",
            self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
        )
        next_kwargs = mapped_kwargs.get(
            "next_kwargs",
            self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs),
        )
        timeout = mapped_kwargs.get(
            "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout)
        )
        return StartTriggerArgs(
            trigger_cls=self.start_trigger_args.trigger_cls,
            trigger_kwargs=trigger_kwargs,
            next_method=self.start_trigger_args.next_method,
            next_kwargs=next_kwargs,
            timeout=timeout,
        )