Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/mappedoperator.py: 49%
389 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections
21import collections.abc
22import contextlib
23import copy
24import datetime
25import warnings
26from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union
28import attr
29import pendulum
30from sqlalchemy.orm.session import Session
32from airflow import settings
33from airflow.compat.functools import cache
34from airflow.exceptions import AirflowException, UnmappableOperator
35from airflow.models.abstractoperator import (
36 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
37 DEFAULT_OWNER,
38 DEFAULT_POOL_SLOTS,
39 DEFAULT_PRIORITY_WEIGHT,
40 DEFAULT_QUEUE,
41 DEFAULT_RETRIES,
42 DEFAULT_RETRY_DELAY,
43 DEFAULT_TRIGGER_RULE,
44 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
45 DEFAULT_WEIGHT_RULE,
46 AbstractOperator,
47 NotMapped,
48 TaskStateChangeCallback,
49)
50from airflow.models.expandinput import (
51 DictOfListsExpandInput,
52 ExpandInput,
53 ListOfDictsExpandInput,
54 OperatorExpandArgument,
55 OperatorExpandKwargsArgument,
56 is_mappable,
57)
58from airflow.models.param import ParamsDict
59from airflow.models.pool import Pool
60from airflow.serialization.enums import DagAttributeTypes
61from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
62from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
63from airflow.typing_compat import Literal
64from airflow.utils.context import Context, context_update_for_unmapped
65from airflow.utils.helpers import is_container, prevent_duplicates
66from airflow.utils.operator_resources import Resources
67from airflow.utils.trigger_rule import TriggerRule
68from airflow.utils.types import NOTSET
69from airflow.utils.xcom import XCOM_RETURN_KEY
71if TYPE_CHECKING:
72 import jinja2 # Slow import.
74 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
75 from airflow.models.dag import DAG
76 from airflow.models.operator import Operator
77 from airflow.models.xcom_arg import XComArg
78 from airflow.utils.task_group import TaskGroup
# Name of the public API (``partial()`` or ``expand()``) whose keyword arguments
# ``validate_mapping_kwargs`` is checking; it appears in error messages and
# ``"expand"`` additionally enables the mappability check.
ValidationSource = Union[Literal["expand"], Literal["partial"]]
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Check that *value* only contains keyword arguments accepted by *op*.

    Accepted parameter names are collected from every class in ``op.mro()``
    whose ``__init__`` carries a ``_BaseOperatorMeta__param_names`` attribute;
    classes without it are skipped. The MRO walk stops as soon as every key in
    *value* has been matched.

    :param op: The operator class being mapped.
    :param func: Which API produced *value*. Used in error messages; for
        ``"expand"`` every matched argument must also pass ``is_mappable``.
    :param value: The keyword arguments to validate. It is not modified.
    :raises ValueError: For ``"expand"``, if a matched argument is not mappable.
    :raises TypeError: If any key is not a parameter of *op* or its bases.
    """
    # Use a dict (not a set) so unknown names are reported in the order the user wrote them.
    unknown_args = value.copy()
    if not unknown_args:
        return  # Nothing to validate.
    check_mappable = func == "expand"
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            # A distinct name so the ``value`` parameter is not shadowed.
            arg = unknown_args.pop(name, NOTSET)
            if not check_mappable or arg is NOTSET or is_mappable(arg):
                continue
            type_name = type(arg).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
def ensure_xcomarg_return_value(arg: Any) -> None:
    """Recursively reject XComArgs that reference anything but the return-value XCom.

    *arg* may be an ``XComArg``, or a container whose mapping values / iterable
    items are checked recursively. Any other value is accepted as-is.

    :raises ValueError: If an ``XComArg`` references an XCom key other than
        ``XCOM_RETURN_KEY``.
    """
    from airflow.models.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        for referenced_op, xcom_key in arg.iter_references():
            if xcom_key != XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {xcom_key!r} from {referenced_op}")
        return

    if not is_container(arg):
        return

    # Mappings are checked by value; other containers by item.
    if isinstance(arg, collections.abc.Mapping):
        children = arg.values()
    elif isinstance(arg, collections.abc.Iterable):
        children = arg
    else:
        return
    for child in children:
        ensure_xcomarg_return_value(child)
@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the DAG.
    """

    operator_class: type[BaseOperator]
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        """Reject SubDagOperator and validate the ``partial()`` kwargs against the class.

        :raises TypeError: If ``operator_class`` is a SubDagOperator; argument
            errors are raised by ``validate_mapping_kwargs``.
        """
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        """Warn when a partial was created but never expanded.

        The ``getattr`` default covers objects whose construction failed before
        ``_expand_called`` was assigned: there is nothing meaningful to warn
        about, and reading the missing attribute would only print an
        "Exception ignored in __del__" traceback.
        """
        if getattr(self, "_expand_called", True):
            return
        try:
            task_id = repr(self.kwargs["task_id"])
        except KeyError:
            task_id = f"at {hex(id(self))}"
        warnings.warn(f"Task {task_id} was never mapped!")

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Create a ``MappedOperator`` that maps over *mapped_kwargs*.

        :raises TypeError: If no keyword arguments are given. Unknown or
            duplicate arguments are reported by ``validate_mapping_kwargs`` and
            ``prevent_duplicates``.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Create a ``MappedOperator`` whose task instances each take one dict of kwargs.

        :param kwargs: An ``XComArg``, or a sequence whose items are each an
            ``XComArg`` or a mapping.
        :param strict: Stored as ``disallow_kwargs_override``: if true, keys that
            duplicate ``partial()`` arguments make execution fail instead of
            overriding them.
        :raises TypeError: If *kwargs* or one of its items has an unsupported type.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    # Name the offending item's type too; reporting only the outer
                    # container (e.g. "not list") hides what is actually wrong.
                    raise TypeError(
                        f"expected XComArg or list[dict], not {type(kwargs).__name__}[{type(item).__name__}]"
                    )
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` and mark this partial as expanded.

        ``task_id``, ``dag``, ``task_group``, ``start_date`` and ``end_date`` are
        moved out of the partial kwargs into dedicated MappedOperator fields; the
        remaining kwargs become ``partial_kwargs``.
        """
        from airflow.operators.empty import EmptyOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
        )
        return op
@attr.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's DAG serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through DAG pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a DAG.

    In this module, instances are built by ``OperatorPartial.expand()`` and
    ``OperatorPartial.expand_kwargs()``. Arguments given to ``partial()`` are kept
    in ``partial_kwargs`` and exposed through read-through properties, while the
    arguments to map over are kept in ``expand_input``. ``unmap()`` combines both
    into a regular operator instance.
    """

    # This attribute serves double purpose. For a "normal" operator instance
    # loaded from DAG, this holds the underlying non-mapped operator class that
    # can be used to create an unmapped operator for execution. For an operator
    # recreated from a serialized DAG, however, this holds the serialized data
    # that can be used to unmap this into a SerializedBaseOperator.
    operator_class: type[BaseOperator] | dict[str, Any]

    # The arguments to map over, as given to expand()/expand_kwargs().
    expand_input: ExpandInput
    # Non-mapped arguments from partial(), minus task_id/dag/task_group/
    # start_date/end_date, which OperatorPartial._expand moves into the
    # dedicated fields below.
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    # As built by OperatorPartial._expand: the operator class's ``deps`` plus
    # MappedTaskIsExpanded (see ``deps_for``).
    deps: frozenset[BaseTIDep]
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    # Private fields; attrs exposes them as __init__ arguments without the
    # leading underscore (OperatorPartial._expand passes ``is_empty=`` etc.).
    _is_empty: bool
    _task_module: str
    _task_type: str
    _operator_name: str

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    # Not __init__ arguments; each instance starts with an empty set.
    upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
    downstream_task_ids: set[str] = attr.ib(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    # Fields with defaults that OperatorPartial._expand does not pass, so
    # instances created there keep these values.
    subdag: None = None  # Since we don't support SubDagOperator, this is always None.
    supports_lineage: bool = False
    is_setup: bool = False
    is_teardown: bool = False
    on_failure_fail_dagrun: bool = False

    # Extends AbstractOperator.HIDE_ATTRS_FROM_UI with mapping-specific names
    # (presumably attributes the UI should not display — confirm in AbstractOperator).
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        (
            "parse_time_mapped_ti_count",
            "operator_class",
        )
    )
304 def __hash__(self):
305 return id(self)
307 def __repr__(self):
308 return f"<Mapped({self._task_type}): {self.task_id}>"
310 def __attrs_post_init__(self):
311 from airflow.models.xcom_arg import XComArg
313 if self.get_closest_mapped_task_group() is not None:
314 raise NotImplementedError("operator expansion in an expanded task group is not yet supported")
316 if self.task_group:
317 self.task_group.add(self)
318 if self.dag:
319 self.dag.add_task(self)
320 XComArg.apply_upstream_relationship(self, self.expand_input.value)
321 for k, v in self.partial_kwargs.items():
322 if k in self.template_fields:
323 XComArg.apply_upstream_relationship(self, v)
324 if self.partial_kwargs.get("sla") is not None:
325 raise AirflowException(
326 f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
327 f"{self.task_id!r}."
328 )
330 @classmethod
331 @cache
332 def get_serialized_fields(cls):
333 # Not using 'cls' here since we only want to serialize base fields.
334 return frozenset(attr.fields_dict(MappedOperator)) - {
335 "dag",
336 "deps",
337 "expand_input", # This is needed to be able to accept XComArg.
338 "subdag",
339 "task_group",
340 "upstream_task_ids",
341 "supports_lineage",
342 "is_setup",
343 "is_teardown",
344 "on_failure_fail_dagrun",
345 }
347 @staticmethod
348 @cache
349 def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
350 operator_deps = operator_class.deps
351 if not isinstance(operator_deps, collections.abc.Set):
352 raise UnmappableOperator(
353 f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
354 f"not a {type(operator_deps).__name__}"
355 )
356 return operator_deps | {MappedTaskIsExpanded()}
358 @property
359 def task_type(self) -> str:
360 """Implementing Operator."""
361 return self._task_type
363 @property
364 def operator_name(self) -> str:
365 return self._operator_name
367 @property
368 def inherits_from_empty_operator(self) -> bool:
369 """Implementing Operator."""
370 return self._is_empty
372 @property
373 def roots(self) -> Sequence[AbstractOperator]:
374 """Implementing DAGNode."""
375 return [self]
377 @property
378 def leaves(self) -> Sequence[AbstractOperator]:
379 """Implementing DAGNode."""
380 return [self]
382 @property
383 def owner(self) -> str: # type: ignore[override]
384 return self.partial_kwargs.get("owner", DEFAULT_OWNER)
386 @property
387 def email(self) -> None | str | Iterable[str]:
388 return self.partial_kwargs.get("email")
390 @property
391 def trigger_rule(self) -> TriggerRule:
392 return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
394 @property
395 def depends_on_past(self) -> bool:
396 return bool(self.partial_kwargs.get("depends_on_past"))
398 @property
399 def ignore_first_depends_on_past(self) -> bool:
400 value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
401 return bool(value)
403 @property
404 def wait_for_past_depends_before_skipping(self) -> bool:
405 value = self.partial_kwargs.get(
406 "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
407 )
408 return bool(value)
410 @property
411 def wait_for_downstream(self) -> bool:
412 return bool(self.partial_kwargs.get("wait_for_downstream"))
414 @property
415 def retries(self) -> int | None:
416 return self.partial_kwargs.get("retries", DEFAULT_RETRIES)
418 @property
419 def queue(self) -> str:
420 return self.partial_kwargs.get("queue", DEFAULT_QUEUE)
422 @property
423 def pool(self) -> str:
424 return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)
426 @property
427 def pool_slots(self) -> str | None:
428 return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)
430 @property
431 def execution_timeout(self) -> datetime.timedelta | None:
432 return self.partial_kwargs.get("execution_timeout")
434 @property
435 def max_retry_delay(self) -> datetime.timedelta | None:
436 return self.partial_kwargs.get("max_retry_delay")
438 @property
439 def retry_delay(self) -> datetime.timedelta:
440 return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)
442 @property
443 def retry_exponential_backoff(self) -> bool:
444 return bool(self.partial_kwargs.get("retry_exponential_backoff"))
446 @property
447 def priority_weight(self) -> int: # type: ignore[override]
448 return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)
450 @property
451 def weight_rule(self) -> int: # type: ignore[override]
452 return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
454 @property
455 def sla(self) -> datetime.timedelta | None:
456 return self.partial_kwargs.get("sla")
458 @property
459 def max_active_tis_per_dag(self) -> int | None:
460 return self.partial_kwargs.get("max_active_tis_per_dag")
462 @property
463 def max_active_tis_per_dagrun(self) -> int | None:
464 return self.partial_kwargs.get("max_active_tis_per_dagrun")
466 @property
467 def resources(self) -> Resources | None:
468 return self.partial_kwargs.get("resources")
470 @property
471 def on_execute_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
472 return self.partial_kwargs.get("on_execute_callback")
474 @on_execute_callback.setter
475 def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None:
476 self.partial_kwargs["on_execute_callback"] = value
478 @property
479 def on_failure_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
480 return self.partial_kwargs.get("on_failure_callback")
482 @on_failure_callback.setter
483 def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None:
484 self.partial_kwargs["on_failure_callback"] = value
486 @property
487 def on_retry_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
488 return self.partial_kwargs.get("on_retry_callback")
490 @on_retry_callback.setter
491 def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
492 self.partial_kwargs["on_retry_callback"] = value
494 @property
495 def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
496 return self.partial_kwargs.get("on_success_callback")
498 @on_success_callback.setter
499 def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
500 self.partial_kwargs["on_success_callback"] = value
502 @property
503 def run_as_user(self) -> str | None:
504 return self.partial_kwargs.get("run_as_user")
506 @property
507 def executor_config(self) -> dict:
508 return self.partial_kwargs.get("executor_config", {})
510 @property # type: ignore[override]
511 def inlets(self) -> list[Any]: # type: ignore[override]
512 return self.partial_kwargs.get("inlets", [])
514 @inlets.setter
515 def inlets(self, value: list[Any]) -> None: # type: ignore[override]
516 self.partial_kwargs["inlets"] = value
518 @property # type: ignore[override]
519 def outlets(self) -> list[Any]: # type: ignore[override]
520 return self.partial_kwargs.get("outlets", [])
522 @outlets.setter
523 def outlets(self, value: list[Any]) -> None: # type: ignore[override]
524 self.partial_kwargs["outlets"] = value
526 @property
527 def doc(self) -> str | None:
528 return self.partial_kwargs.get("doc")
530 @property
531 def doc_md(self) -> str | None:
532 return self.partial_kwargs.get("doc_md")
534 @property
535 def doc_json(self) -> str | None:
536 return self.partial_kwargs.get("doc_json")
538 @property
539 def doc_yaml(self) -> str | None:
540 return self.partial_kwargs.get("doc_yaml")
542 @property
543 def doc_rst(self) -> str | None:
544 return self.partial_kwargs.get("doc_rst")
546 def get_dag(self) -> DAG | None:
547 """Implementing Operator."""
548 return self.dag
550 @property
551 def output(self) -> XComArg:
552 """Returns reference to XCom pushed by current operator."""
553 from airflow.models.xcom_arg import XComArg
555 return XComArg(operator=self)
557 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
558 """Implementing DAGNode."""
559 return DagAttributeTypes.OP, self.task_id
561 def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
562 """Get the kwargs to create the unmapped operator.
564 This exists because taskflow operators expand against op_kwargs, not the
565 entire operator kwargs dict.
566 """
567 return self._get_specified_expand_input().resolve(context, session)
569 def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
570 """Get init kwargs to unmap the underlying operator class.
572 :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
573 """
574 if strict:
575 prevent_duplicates(
576 self.partial_kwargs,
577 mapped_kwargs,
578 fail_reason="unmappable or already specified",
579 )
581 # If params appears in the mapped kwargs, we need to merge it into the
582 # partial params, overriding existing keys.
583 params = copy.copy(self.params)
584 with contextlib.suppress(KeyError):
585 params.update(mapped_kwargs["params"])
587 # Ordering is significant; mapped kwargs should override partial ones,
588 # and the specially handled params should be respected.
589 return {
590 "task_id": self.task_id,
591 "dag": self.dag,
592 "task_group": self.task_group,
593 "start_date": self.start_date,
594 "end_date": self.end_date,
595 **self.partial_kwargs,
596 **mapped_kwargs,
597 "params": params,
598 }
600 def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
601 """Get the "normal" Operator after applying the current mapping.
603 The *resolve* argument is only used if ``operator_class`` is a real
604 class, i.e. if this operator is not serialized. If ``operator_class`` is
605 not a class (i.e. this DAG has been deserialized), this returns a
606 SerializedBaseOperator that "looks like" the actual unmapping result.
608 If *resolve* is a two-tuple (context, session), the information is used
609 to resolve the mapped arguments into init arguments. If it is a mapping,
610 no resolving happens, the mapping directly provides those init arguments
611 resolved from mapped kwargs.
613 :meta private:
614 """
615 if isinstance(self.operator_class, type):
616 if isinstance(resolve, collections.abc.Mapping):
617 kwargs = resolve
618 elif resolve is not None:
619 kwargs, _ = self._expand_mapped_kwargs(*resolve)
620 else:
621 raise RuntimeError("cannot unmap a non-serialized operator without context")
622 kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
623 op = self.operator_class(**kwargs, _airflow_from_mapped=True)
624 # We need to overwrite task_id here because BaseOperator further
625 # mangles the task_id based on the task hierarchy (namely, group_id
626 # is prepended, and '__N' appended to deduplicate). This is hacky,
627 # but better than duplicating the whole mangling logic.
628 op.task_id = self.task_id
629 return op
631 # After a mapped operator is serialized, there's no real way to actually
632 # unmap it since we've lost access to the underlying operator class.
633 # This tries its best to simply "forward" all the attributes on this
634 # mapped operator to a new SerializedBaseOperator instance.
635 from airflow.serialization.serialized_objects import SerializedBaseOperator
637 op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
638 SerializedBaseOperator.populate_operator(op, self.operator_class)
639 return op
641 def _get_specified_expand_input(self) -> ExpandInput:
642 """Input received from the expand call on the operator."""
643 return getattr(self, self._expand_input_attr)
645 def prepare_for_execution(self) -> MappedOperator:
646 # Since a mapped operator cannot be used for execution, and an unmapped
647 # BaseOperator needs to be created later (see render_template_fields),
648 # we don't need to create a copy of the MappedOperator here.
649 return self
651 def iter_mapped_dependencies(self) -> Iterator[Operator]:
652 """Upstream dependencies that provide XComs used by this task for task mapping."""
653 from airflow.models.xcom_arg import XComArg
655 for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
656 yield operator
658 @cache
659 def get_parse_time_mapped_ti_count(self) -> int:
660 current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
661 try:
662 parent_count = super().get_parse_time_mapped_ti_count()
663 except NotMapped:
664 return current_count
665 return parent_count * current_count
667 def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
668 current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
669 try:
670 parent_count = super().get_mapped_ti_count(run_id, session=session)
671 except NotMapped:
672 return current_count
673 return parent_count * current_count
675 def render_template_fields(
676 self,
677 context: Context,
678 jinja_env: jinja2.Environment | None = None,
679 ) -> None:
680 """Template all attributes listed in *self.template_fields*.
682 This updates *context* to reference the map-expanded task and relevant
683 information, without modifying the mapped operator. The expanded task
684 in *context* is then rendered in-place.
686 :param context: Context dict with values to apply on content.
687 :param jinja_env: Jinja environment to use for rendering.
688 """
689 if not jinja_env:
690 jinja_env = self.get_template_env()
692 # Ideally we'd like to pass in session as an argument to this function,
693 # but we can't easily change this function signature since operators
694 # could override this. We can't use @provide_session since it closes and
695 # expunges everything, which we don't want to do when we are so "deep"
696 # in the weeds here. We don't close this session for the same reason.
697 session = settings.Session()
699 mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
700 unmapped_task = self.unmap(mapped_kwargs)
701 context_update_for_unmapped(context, unmapped_task)
703 # Since the operators that extend `BaseOperator` are not subclasses of
704 # `MappedOperator`, we need to call `_do_render_template_fields` from
705 # the unmapped task in order to call the operator method when we override
706 # it to customize the parsing of nested fields.
707 unmapped_task._do_render_template_fields(
708 parent=unmapped_task,
709 template_fields=self.template_fields,
710 context=context,
711 jinja_env=jinja_env,
712 seen_oids=seen_oids,
713 session=session,
714 )