Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/mappedoperator.py: 49%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections.abc
21import contextlib
22import copy
23import warnings
24from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union
26import attr
27import methodtools
29from airflow.exceptions import AirflowException, UnmappableOperator
30from airflow.models.abstractoperator import (
31 DEFAULT_EXECUTOR,
32 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
33 DEFAULT_OWNER,
34 DEFAULT_POOL_SLOTS,
35 DEFAULT_PRIORITY_WEIGHT,
36 DEFAULT_QUEUE,
37 DEFAULT_RETRIES,
38 DEFAULT_RETRY_DELAY,
39 DEFAULT_TRIGGER_RULE,
40 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
41 DEFAULT_WEIGHT_RULE,
42 AbstractOperator,
43 NotMapped,
44)
45from airflow.models.expandinput import (
46 DictOfListsExpandInput,
47 ListOfDictsExpandInput,
48 is_mappable,
49)
50from airflow.models.pool import Pool
51from airflow.serialization.enums import DagAttributeTypes
52from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
53from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
54from airflow.typing_compat import Literal
55from airflow.utils.context import context_update_for_unmapped
56from airflow.utils.helpers import is_container, prevent_duplicates
57from airflow.utils.task_instance_session import get_current_task_instance_session
58from airflow.utils.types import NOTSET
59from airflow.utils.xcom import XCOM_RETURN_KEY
61if TYPE_CHECKING:
62 import datetime
63 from typing import List
65 import jinja2 # Slow import.
66 import pendulum
67 from sqlalchemy.orm.session import Session
69 from airflow.models.abstractoperator import (
70 TaskStateChangeCallback,
71 )
72 from airflow.models.baseoperator import BaseOperator
73 from airflow.models.baseoperatorlink import BaseOperatorLink
74 from airflow.models.dag import DAG
75 from airflow.models.expandinput import (
76 ExpandInput,
77 OperatorExpandArgument,
78 OperatorExpandKwargsArgument,
79 )
80 from airflow.models.operator import Operator
81 from airflow.models.param import ParamsDict
82 from airflow.models.xcom_arg import XComArg
83 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
84 from airflow.triggers.base import BaseTrigger
85 from airflow.utils.context import Context
86 from airflow.utils.operator_resources import Resources
87 from airflow.utils.task_group import TaskGroup
88 from airflow.utils.trigger_rule import TriggerRule
90 TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, List[TaskStateChangeCallback]]
92ValidationSource = Union[Literal["expand"], Literal["partial"]]
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Validate *value* against the constructor parameters of *op*.

    Walks the operator's MRO, consuming keys recorded by ``BaseOperatorMeta``
    on each ``__init__``. For ``expand`` calls every supplied value must also
    be mappable.

    :param op: Operator class being partial-ed or expanded.
    :param func: Which API is being validated ("expand" or "partial").
    :param value: Keyword arguments supplied by the user.
    :raises ValueError: An ``expand()`` argument has an unmappable type.
    :raises TypeError: A keyword argument matches no constructor parameter.
    """
    # Work on a copy (a dict keeps the same order as the call site) so the
    # caller's mapping is left untouched.
    remaining = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            supplied = remaining.pop(name, NOTSET)
            # Mappability only matters for expand(); absent values are fine.
            if func != "expand" or supplied is NOTSET or is_mappable(supplied):
                continue
            type_name = type(supplied).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not remaining:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(remaining) == 1:
        error = f"an unexpected keyword argument {remaining.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in remaining)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
def ensure_xcomarg_return_value(arg: Any) -> None:
    """Recursively reject XComArg references that use a non-default key.

    Task mapping only supports mapping over an upstream task's return value,
    so any XComArg reference with a custom key raises ``ValueError``.
    Containers (mappings and other iterables) are descended into; everything
    else is ignored.
    """
    from airflow.models.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        for operator, key in arg.iter_references():
            if key == XCOM_RETURN_KEY:
                continue
            raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not is_container(arg):
        return
    if isinstance(arg, collections.abc.Mapping):
        children: Iterable[Any] = arg.values()
    elif isinstance(arg, collections.abc.Iterable):
        children = arg
    else:
        return
    for child in children:
        ensure_xcomarg_return_value(child)
@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the DAG.
    """

    # The concrete operator subclass being mapped.
    operator_class: type[BaseOperator]
    # Keyword arguments captured by partial(); shared by all expanded tasks.
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        # Local import to avoid pulling in the deprecated subdag module at
        # module-import time.
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # Warn on garbage collection if partial() was never expand()-ed —
        # almost certainly a user mistake, since the task never made it into
        # the DAG.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Expand against per-argument collections (dict-of-lists form).

        :raises TypeError: No arguments, or an unknown keyword argument.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Expand against a sequence of whole-kwargs mappings, or an XComArg yielding one."""
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the final ``MappedOperator`` from this partial and *expand_input*."""
        from airflow.operators.empty import EmptyOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        # These keys get dedicated MappedOperator attributes, so they must be
        # removed from the shared partial_kwargs dict.
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            start_trigger=self.operator_class.start_trigger,
            next_method=self.operator_class.next_method,
        )
        return op
@attr.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's DAG serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through DAG pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a DAG."""

    # This attribute serves double purpose. For a "normal" operator instance
    # loaded from DAG, this holds the underlying non-mapped operator class that
    # can be used to create an unmapped operator for execution. For an operator
    # recreated from a serialized DAG, however, this holds the serialized data
    # that can be used to unmap this into a SerializedBaseOperator.
    operator_class: type[BaseOperator] | dict[str, Any]

    # How this operator expands, plus the constant kwargs shared by every
    # expanded task instance.
    expand_input: ExpandInput
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    deps: frozenset[BaseTIDep]
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool
    _task_module: str
    _task_type: str
    _operator_name: str
    start_trigger: BaseTrigger | None
    next_method: str | None
    _needs_expansion: bool = True

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    # Relationship sets are populated after construction, never via __init__.
    upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
    downstream_task_ids: set[str] = attr.ib(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    subdag: None = None  # Since we don't support SubDagOperator, this is always None.
    supports_lineage: bool = False

    # Attributes hidden from the task-details UI view, in addition to those
    # already hidden by AbstractOperator.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        (
            "parse_time_mapped_ti_count",
            "operator_class",
            "start_trigger",
            "next_method",
        )
    )
    def __hash__(self):
        # Identity-based hash: every mapped task object is treated as unique.
        return id(self)

    def __repr__(self):
        return f"<Mapped({self._task_type}): {self.task_id}>"
    def __attrs_post_init__(self):
        from airflow.models.xcom_arg import XComArg

        # Nesting an expansion inside an already-expanded task group is not
        # implemented.
        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        # Register this task with its group and DAG before wiring upstream
        # relationships below.
        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        # Any XComArg in the expand input, or in a templated partial kwarg,
        # implies an upstream dependency on the producing task.
        XComArg.apply_upstream_relationship(self, self.expand_input.value)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)
        if self.partial_kwargs.get("sla") is not None:
            raise AirflowException(
                f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
                f"{self.task_id!r}."
            )
346 @methodtools.lru_cache(maxsize=None)
347 @classmethod
348 def get_serialized_fields(cls):
349 # Not using 'cls' here since we only want to serialize base fields.
350 return frozenset(attr.fields_dict(MappedOperator)) - {
351 "dag",
352 "deps",
353 "expand_input", # This is needed to be able to accept XComArg.
354 "subdag",
355 "task_group",
356 "upstream_task_ids",
357 "supports_lineage",
358 "_is_setup",
359 "_is_teardown",
360 "_on_failure_fail_dagrun",
361 }
363 @methodtools.lru_cache(maxsize=None)
364 @staticmethod
365 def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
366 operator_deps = operator_class.deps
367 if not isinstance(operator_deps, collections.abc.Set):
368 raise UnmappableOperator(
369 f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
370 f"not a {type(operator_deps).__name__}"
371 )
372 return operator_deps | {MappedTaskIsExpanded()}
    # A mapped operator is a single DAG node, so it is both its own root and
    # its own leaf.

    @property
    def task_type(self) -> str:
        """Implementing Operator."""
        return self._task_type

    @property
    def operator_name(self) -> str:
        """Display name of the underlying operator class."""
        return self._operator_name

    @property
    def inherits_from_empty_operator(self) -> bool:
        """Implementing Operator."""
        return self._is_empty

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]
    # The attributes below are not stored on the instance; they proxy into
    # ``partial_kwargs`` (falling back to operator-level defaults) so the
    # mapped task presents the same attribute surface as a regular operator.

    @property
    def task_display_name(self) -> str:
        # Falls back to the task_id when no display name was set.
        return self.partial_kwargs.get("task_display_name") or self.task_id

    @property
    def owner(self) -> str:  # type: ignore[override]
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def map_index_template(self) -> None | str:
        return self.partial_kwargs.get("map_index_template")

    @map_index_template.setter
    def map_index_template(self, value: str | None) -> None:
        self.partial_kwargs["map_index_template"] = value

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @trigger_rule.setter
    def trigger_rule(self, value):
        self.partial_kwargs["trigger_rule"] = value

    @property
    def is_setup(self) -> bool:
        return bool(self.partial_kwargs.get("is_setup"))

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self.partial_kwargs["is_setup"] = value

    @property
    def is_teardown(self) -> bool:
        return bool(self.partial_kwargs.get("is_teardown"))

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        self.partial_kwargs["is_teardown"] = value

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @depends_on_past.setter
    def depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["depends_on_past"] = value

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @ignore_first_depends_on_past.setter
    def ignore_first_depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["ignore_first_depends_on_past"] = value

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @wait_for_past_depends_before_skipping.setter
    def wait_for_past_depends_before_skipping(self, value: bool) -> None:
        self.partial_kwargs["wait_for_past_depends_before_skipping"] = value

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @wait_for_downstream.setter
    def wait_for_downstream(self, value: bool) -> None:
        self.partial_kwargs["wait_for_downstream"] = value

    @property
    def retries(self) -> int:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @retries.setter
    def retries(self, value: int) -> None:
        self.partial_kwargs["retries"] = value

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @queue.setter
    def queue(self, value: str) -> None:
        self.partial_kwargs["queue"] = value

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)

    @pool.setter
    def pool(self, value: str) -> None:
        self.partial_kwargs["pool"] = value

    @property
    def pool_slots(self) -> int:
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @pool_slots.setter
    def pool_slots(self, value: int) -> None:
        self.partial_kwargs["pool_slots"] = value
    # Timing, retry, priority and concurrency attributes — all proxied into
    # ``partial_kwargs`` like the attributes above.

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @execution_timeout.setter
    def execution_timeout(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["execution_timeout"] = value

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @max_retry_delay.setter
    def max_retry_delay(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["max_retry_delay"] = value

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @retry_delay.setter
    def retry_delay(self, value: datetime.timedelta) -> None:
        self.partial_kwargs["retry_delay"] = value

    @property
    def retry_exponential_backoff(self) -> bool:
        return bool(self.partial_kwargs.get("retry_exponential_backoff"))

    @retry_exponential_backoff.setter
    def retry_exponential_backoff(self, value: bool) -> None:
        self.partial_kwargs["retry_exponential_backoff"] = value

    @property
    def priority_weight(self) -> int:  # type: ignore[override]
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @priority_weight.setter
    def priority_weight(self, value: int) -> None:
        self.partial_kwargs["priority_weight"] = value

    @property
    def weight_rule(self) -> PriorityWeightStrategy:  # type: ignore[override]
        # Stored value may be a string name or a strategy object; both the
        # getter and setter normalize through the validation helper.
        return validate_and_load_priority_weight_strategy(
            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
        )

    @weight_rule.setter
    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
        self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

    @property
    def sla(self) -> datetime.timedelta | None:
        # Note: a non-None sla is rejected in __attrs_post_init__.
        return self.partial_kwargs.get("sla")

    @sla.setter
    def sla(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["sla"] = value

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @max_active_tis_per_dag.setter
    def max_active_tis_per_dag(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dag"] = value

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @max_active_tis_per_dagrun.setter
    def max_active_tis_per_dagrun(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dagrun"] = value

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")
    # Task state-change callbacks, proxied into ``partial_kwargs``.

    @property
    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_execute_callback")

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_execute_callback"] = value

    @property
    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_failure_callback")

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_failure_callback"] = value

    @property
    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_retry_callback")

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_retry_callback"] = value

    @property
    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_success_callback")

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_success_callback"] = value

    @property
    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_skipped_callback")

    @on_skipped_callback.setter
    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_skipped_callback"] = value
    # Execution environment, lineage and documentation attributes, proxied
    # into ``partial_kwargs``.

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor(self) -> str | None:
        return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property  # type: ignore[override]
    def inlets(self) -> list[Any]:  # type: ignore[override]
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["inlets"] = value

    @property  # type: ignore[override]
    def outlets(self) -> list[Any]:  # type: ignore[override]
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["outlets"] = value

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    @property
    def allow_nested_operators(self) -> bool:
        return bool(self.partial_kwargs.get("allow_nested_operators"))
    def get_dag(self) -> DAG | None:
        """Implement Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return reference to XCom pushed by current operator."""
        from airflow.models.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implement DAGNode."""
        return DagAttributeTypes.OP, self.task_id
    def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.

        The second tuple element is forwarded to templating as ``seen_oids``
        (see ``render_template_fields``).
        """
        return self._get_specified_expand_input().resolve(context, session)
    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """Get init kwargs to unmap the underlying operator class.

        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
        :param strict: If *True*, keys duplicated between partial and mapped
            kwargs raise instead of being overridden.
        """
        if strict:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # If params appears in the mapped kwargs, we need to merge it into the
        # partial params, overriding existing keys.
        params = copy.copy(self.params)
        with contextlib.suppress(KeyError):
            params.update(mapped_kwargs["params"])

        # Ordering is significant; mapped kwargs should override partial ones,
        # and the specially handled params should be respected.
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **mapped_kwargs,
            "params": params,
        }
    def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" Operator after applying the current mapping.

        The *resolve* argument is only used if ``operator_class`` is a real
        class, i.e. if this operator is not serialized. If ``operator_class`` is
        not a class (i.e. this DAG has been deserialized), this returns a
        SerializedBaseOperator that "looks like" the actual unmapping result.

        If *resolve* is a two-tuple (context, session), the information is used
        to resolve the mapped arguments into init arguments. If it is a mapping,
        no resolving happens, the mapping directly provides those init arguments
        resolved from mapped kwargs.

        :meta private:
        """
        if isinstance(self.operator_class, type):
            if isinstance(resolve, collections.abc.Mapping):
                kwargs = resolve
            elif resolve is not None:
                kwargs, _ = self._expand_mapped_kwargs(*resolve)
            else:
                raise RuntimeError("cannot unmap a non-serialized operator without context")
            kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
            # These three are applied after construction; BaseOperator.__init__
            # does not accept them as keyword arguments here.
            is_setup = kwargs.pop("is_setup", False)
            is_teardown = kwargs.pop("is_teardown", False)
            on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
            op = self.operator_class(**kwargs, _airflow_from_mapped=True)
            # We need to overwrite task_id here because BaseOperator further
            # mangles the task_id based on the task hierarchy (namely, group_id
            # is prepended, and '__N' appended to deduplicate). This is hacky,
            # but better than duplicating the whole mangling logic.
            op.task_id = self.task_id
            op.is_setup = is_setup
            op.is_teardown = is_teardown
            op.on_failure_fail_dagrun = on_failure_fail_dagrun
            return op

        # After a mapped operator is serialized, there's no real way to actually
        # unmap it since we've lost access to the underlying operator class.
        # This tries its best to simply "forward" all the attributes on this
        # mapped operator to a new SerializedBaseOperator instance.
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
        SerializedBaseOperator.populate_operator(op, self.operator_class)
        if self.dag is not None:  # For Mypy; we only serialize tasks in a DAG so the check always satisfies.
            SerializedBaseOperator.set_task_dag_references(op, self.dag)
        return op
    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.models.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator
    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of mapped task instances determinable at DAG-parse time.

        NOTE(review): this caches on an instance method; presumably
        methodtools avoids the plain functools pitfall of keeping every
        instance alive — confirm against methodtools documentation.
        """
        current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
        try:
            parent_count = super().get_parse_time_mapped_ti_count()
        except NotMapped:
            # No mapped ancestor: this operator's own expansion is the total.
            return current_count
        return parent_count * current_count
    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Total number of mapped task instances for the given *run_id*.

        Multiplies this operator's own expansion length by the parent's
        (e.g. an enclosing mapped structure) when one exists.
        """
        current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
        try:
            parent_count = super().get_mapped_ti_count(run_id, session=session)
        except NotMapped:
            # No mapped ancestor: this operator's own expansion is the total.
            return current_count
        return parent_count * current_count
    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        if not jinja_env:
            jinja_env = self.get_template_env()

        # We retrieve the session here, stored by _run_raw_task in set_current_task_session
        # context manager - we cannot pass the session via @provide_session because the signature
        # of render_template_fields is defined by BaseOperator and there are already many subclasses
        # overriding it, so changing the signature is not an option. However render_template_fields is
        # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the
        # set_current_task_session context manager to store the session in the current task.
        session = get_current_task_instance_session()

        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
        )