Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/mappedoperator.py: 48%
377 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections
21import collections.abc
22import contextlib
23import copy
24import datetime
25import warnings
26from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union
28import attr
29import pendulum
30from sqlalchemy.orm.session import Session
32from airflow import settings
33from airflow.compat.functools import cache
34from airflow.exceptions import AirflowException, UnmappableOperator
35from airflow.models.abstractoperator import (
36 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
37 DEFAULT_OWNER,
38 DEFAULT_POOL_SLOTS,
39 DEFAULT_PRIORITY_WEIGHT,
40 DEFAULT_QUEUE,
41 DEFAULT_RETRIES,
42 DEFAULT_RETRY_DELAY,
43 DEFAULT_TRIGGER_RULE,
44 DEFAULT_WEIGHT_RULE,
45 AbstractOperator,
46 NotMapped,
47 TaskStateChangeCallback,
48)
49from airflow.models.expandinput import (
50 DictOfListsExpandInput,
51 ExpandInput,
52 ListOfDictsExpandInput,
53 OperatorExpandArgument,
54 OperatorExpandKwargsArgument,
55 is_mappable,
56)
57from airflow.models.param import ParamsDict
58from airflow.models.pool import Pool
59from airflow.serialization.enums import DagAttributeTypes
60from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
61from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
62from airflow.typing_compat import Literal
63from airflow.utils.context import Context, context_update_for_unmapped
64from airflow.utils.helpers import is_container, prevent_duplicates
65from airflow.utils.operator_resources import Resources
66from airflow.utils.trigger_rule import TriggerRule
67from airflow.utils.types import NOTSET
69if TYPE_CHECKING:
70 import jinja2 # Slow import.
72 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
73 from airflow.models.dag import DAG
74 from airflow.models.operator import Operator
75 from airflow.models.xcom_arg import XComArg
76 from airflow.utils.task_group import TaskGroup
# Which mapping entry point is being validated: ``partial()`` or ``expand()``.
# validate_mapping_kwargs uses it both to decide whether values must be
# mappable and to name the method in error messages.
ValidationSource = Union[Literal["expand"], Literal["partial"]]


def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Check that every keyword in *value* is accepted by the operator class *op*.

    The MRO of *op* is walked and, for every class whose ``__init__`` carries a
    ``_BaseOperatorMeta__param_names`` attribute (presumably recorded by
    ``BaseOperatorMeta``; classes without it are skipped), the declared
    parameter names are consumed from a copy of *value*.

    :param op: Operator class the keywords are meant for.
    :param func: ``"partial"`` or ``"expand"``; for ``"expand"`` every consumed
        value must additionally satisfy ``is_mappable``.
    :param value: Keyword arguments passed to ``partial()`` / ``expand()``.
        It is not modified.
    :raises ValueError: if *func* is ``"expand"`` and a known keyword has a
        value that cannot be mapped over.
    :raises TypeError: if any keyword is not accepted by any class in the MRO.
    """
    # Use a dict so the order of args in error messages matches code order.
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            # Named ``arg_value`` so the ``value`` parameter is not shadowed.
            arg_value = unknown_args.pop(name, NOTSET)
            if func != "expand" or arg_value is NOTSET or is_mappable(arg_value):
                continue
            type_name = type(arg_value).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    # Nothing left over (e.g. *value* was empty and no class declared params):
    # do not raise a TypeError that names no arguments.
    if not unknown_args:
        return
    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
def ensure_xcomarg_return_value(arg: Any) -> None:
    """Reject XComArgs that point at anything other than the default return key.

    *arg* may be an ``XComArg``, a non-container value, or a (possibly nested)
    mapping / iterable of those. Mappings are inspected through their values,
    other iterables element by element; anything that is not a container is
    accepted as-is.

    :raises ValueError: if any referenced XCom uses a key other than
        ``XCOM_RETURN_KEY``.
    """
    from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg

    if isinstance(arg, XComArg):
        for referenced_op, ref_key in arg.iter_references():
            if ref_key == XCOM_RETURN_KEY:
                continue
            raise ValueError(f"cannot map over XCom with custom key {ref_key!r} from {referenced_op}")
        return

    if not is_container(arg):
        return

    # Mappings are checked via their values; other iterables directly.
    if isinstance(arg, collections.abc.Mapping):
        children = arg.values()
    elif isinstance(arg, collections.abc.Iterable):
        children = arg
    else:
        return

    for child in children:
        ensure_xcomarg_return_value(child)
@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the DAG.
    """

    operator_class: type[BaseOperator]
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # A partial that is garbage-collected without expand()/expand_kwargs()
        # having been called never produced a task; warn so the user notices.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!")

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Create a ``MappedOperator`` expanding over the given keyword arguments.

        :raises TypeError: if no keyword arguments are given, or if a keyword
            is unknown to the operator or was already passed to ``partial()``.
        :raises ValueError: if a value cannot be mapped over.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Create a ``MappedOperator`` expanding over a list of kwargs dicts.

        :param kwargs: An ``XComArg``, or a sequence whose items are each a
            mapping or an ``XComArg``.
        :param strict: Forwarded as ``disallow_kwargs_override``; when True,
            execution fails if expanded keys duplicate ``partial()`` keys.
        :raises TypeError: if *kwargs* (or one of its items) has the wrong type.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    # Name the offending element's type; the container type
                    # alone (e.g. "list") does not tell the user what is wrong.
                    raise TypeError(
                        f"expected XComArg or list[dict], not {type(kwargs).__name__} "
                        f"containing {type(item).__name__}"
                    )
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` from the partial kwargs and *expand_input*.

        ``task_id``, ``dag``, ``task_group``, ``start_date`` and ``end_date``
        are popped out of a copy of the partial kwargs (they must be present)
        and passed as dedicated fields; the rest become ``partial_kwargs``.
        """
        from airflow.operators.empty import EmptyOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
        )
        return op
230@attr.define(
231 kw_only=True,
232 # Disable custom __getstate__ and __setstate__ generation since it interacts
233 # badly with Airflow's DAG serialization and pickling. When a mapped task is
234 # deserialized, subclasses are coerced into MappedOperator, but when it goes
235 # through DAG pickling, all attributes defined in the subclasses are dropped
236 # by attrs's custom state management. Since attrs does not do anything too
237 # special here (the logic is only important for slots=True), we use Python's
238 # built-in implementation, which works (as proven by good old BaseOperator).
239 getstate_setstate=False,
240)
241class MappedOperator(AbstractOperator):
242 """Object representing a mapped operator in a DAG."""
244 # This attribute serves double purpose. For a "normal" operator instance
245 # loaded from DAG, this holds the underlying non-mapped operator class that
246 # can be used to create an unmapped operator for execution. For an operator
247 # recreated from a serialized DAG, however, this holds the serialized data
248 # that can be used to unmap this into a SerializedBaseOperator.
249 operator_class: type[BaseOperator] | dict[str, Any]
251 expand_input: ExpandInput
252 partial_kwargs: dict[str, Any]
254 # Needed for serialization.
255 task_id: str
256 params: ParamsDict | dict
257 deps: frozenset[BaseTIDep]
258 operator_extra_links: Collection[BaseOperatorLink]
259 template_ext: Sequence[str]
260 template_fields: Collection[str]
261 template_fields_renderers: dict[str, str]
262 ui_color: str
263 ui_fgcolor: str
264 _is_empty: bool
265 _task_module: str
266 _task_type: str
267 _operator_name: str
269 dag: DAG | None
270 task_group: TaskGroup | None
271 start_date: pendulum.DateTime | None
272 end_date: pendulum.DateTime | None
273 upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
274 downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
276 _disallow_kwargs_override: bool
277 """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.
279 If *False*, values from ``expand_input`` under duplicate keys override those
280 under corresponding keys in ``partial_kwargs``.
281 """
283 _expand_input_attr: str
284 """Where to get kwargs to calculate expansion length against.
286 This should be a name to call ``getattr()`` on.
287 """
289 subdag: None = None # Since we don't support SubDagOperator, this is always None.
291 HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
292 (
293 "parse_time_mapped_ti_count",
294 "operator_class",
295 )
296 )
298 def __hash__(self):
299 return id(self)
301 def __repr__(self):
302 return f"<Mapped({self._task_type}): {self.task_id}>"
304 def __attrs_post_init__(self):
305 from airflow.models.xcom_arg import XComArg
307 if self.get_closest_mapped_task_group() is not None:
308 raise NotImplementedError("operator expansion in an expanded task group is not yet supported")
310 if self.task_group:
311 self.task_group.add(self)
312 if self.dag:
313 self.dag.add_task(self)
314 XComArg.apply_upstream_relationship(self, self.expand_input.value)
315 for k, v in self.partial_kwargs.items():
316 if k in self.template_fields:
317 XComArg.apply_upstream_relationship(self, v)
318 if self.partial_kwargs.get("sla") is not None:
319 raise AirflowException(
320 f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
321 f"{self.task_id!r}."
322 )
324 @classmethod
325 @cache
326 def get_serialized_fields(cls):
327 # Not using 'cls' here since we only want to serialize base fields.
328 return frozenset(attr.fields_dict(MappedOperator)) - {
329 "dag",
330 "deps",
331 "expand_input", # This is needed to be able to accept XComArg.
332 "subdag",
333 "task_group",
334 "upstream_task_ids",
335 }
337 @staticmethod
338 @cache
339 def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
340 operator_deps = operator_class.deps
341 if not isinstance(operator_deps, collections.abc.Set):
342 raise UnmappableOperator(
343 f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
344 f"not a {type(operator_deps).__name__}"
345 )
346 return operator_deps | {MappedTaskIsExpanded()}
348 @property
349 def task_type(self) -> str:
350 """Implementing Operator."""
351 return self._task_type
353 @property
354 def operator_name(self) -> str:
355 return self._operator_name
357 @property
358 def inherits_from_empty_operator(self) -> bool:
359 """Implementing Operator."""
360 return self._is_empty
362 @property
363 def roots(self) -> Sequence[AbstractOperator]:
364 """Implementing DAGNode."""
365 return [self]
367 @property
368 def leaves(self) -> Sequence[AbstractOperator]:
369 """Implementing DAGNode."""
370 return [self]
372 @property
373 def owner(self) -> str: # type: ignore[override]
374 return self.partial_kwargs.get("owner", DEFAULT_OWNER)
376 @property
377 def email(self) -> None | str | Iterable[str]:
378 return self.partial_kwargs.get("email")
380 @property
381 def trigger_rule(self) -> TriggerRule:
382 return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
384 @property
385 def depends_on_past(self) -> bool:
386 return bool(self.partial_kwargs.get("depends_on_past"))
388 @property
389 def ignore_first_depends_on_past(self) -> bool:
390 value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
391 return bool(value)
393 @property
394 def wait_for_downstream(self) -> bool:
395 return bool(self.partial_kwargs.get("wait_for_downstream"))
397 @property
398 def retries(self) -> int | None:
399 return self.partial_kwargs.get("retries", DEFAULT_RETRIES)
401 @property
402 def queue(self) -> str:
403 return self.partial_kwargs.get("queue", DEFAULT_QUEUE)
405 @property
406 def pool(self) -> str:
407 return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)
409 @property
410 def pool_slots(self) -> str | None:
411 return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)
413 @property
414 def execution_timeout(self) -> datetime.timedelta | None:
415 return self.partial_kwargs.get("execution_timeout")
417 @property
418 def max_retry_delay(self) -> datetime.timedelta | None:
419 return self.partial_kwargs.get("max_retry_delay")
421 @property
422 def retry_delay(self) -> datetime.timedelta:
423 return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)
425 @property
426 def retry_exponential_backoff(self) -> bool:
427 return bool(self.partial_kwargs.get("retry_exponential_backoff"))
429 @property
430 def priority_weight(self) -> int: # type: ignore[override]
431 return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)
433 @property
434 def weight_rule(self) -> int: # type: ignore[override]
435 return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
437 @property
438 def sla(self) -> datetime.timedelta | None:
439 return self.partial_kwargs.get("sla")
441 @property
442 def max_active_tis_per_dag(self) -> int | None:
443 return self.partial_kwargs.get("max_active_tis_per_dag")
445 @property
446 def resources(self) -> Resources | None:
447 return self.partial_kwargs.get("resources")
449 @property
450 def on_execute_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
451 return self.partial_kwargs.get("on_execute_callback")
453 @on_execute_callback.setter
454 def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None:
455 self.partial_kwargs["on_execute_callback"] = value
457 @property
458 def on_failure_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
459 return self.partial_kwargs.get("on_failure_callback")
461 @on_failure_callback.setter
462 def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None:
463 self.partial_kwargs["on_failure_callback"] = value
465 @property
466 def on_retry_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
467 return self.partial_kwargs.get("on_retry_callback")
469 @on_retry_callback.setter
470 def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
471 self.partial_kwargs["on_retry_callback"] = value
473 @property
474 def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
475 return self.partial_kwargs.get("on_success_callback")
477 @on_success_callback.setter
478 def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
479 self.partial_kwargs["on_success_callback"] = value
481 @property
482 def run_as_user(self) -> str | None:
483 return self.partial_kwargs.get("run_as_user")
485 @property
486 def executor_config(self) -> dict:
487 return self.partial_kwargs.get("executor_config", {})
489 @property # type: ignore[override]
490 def inlets(self) -> list[Any]: # type: ignore[override]
491 return self.partial_kwargs.get("inlets", [])
493 @inlets.setter
494 def inlets(self, value: list[Any]) -> None: # type: ignore[override]
495 self.partial_kwargs["inlets"] = value
497 @property # type: ignore[override]
498 def outlets(self) -> list[Any]: # type: ignore[override]
499 return self.partial_kwargs.get("outlets", [])
501 @outlets.setter
502 def outlets(self, value: list[Any]) -> None: # type: ignore[override]
503 self.partial_kwargs["outlets"] = value
505 @property
506 def doc(self) -> str | None:
507 return self.partial_kwargs.get("doc")
509 @property
510 def doc_md(self) -> str | None:
511 return self.partial_kwargs.get("doc_md")
513 @property
514 def doc_json(self) -> str | None:
515 return self.partial_kwargs.get("doc_json")
517 @property
518 def doc_yaml(self) -> str | None:
519 return self.partial_kwargs.get("doc_yaml")
521 @property
522 def doc_rst(self) -> str | None:
523 return self.partial_kwargs.get("doc_rst")
525 def get_dag(self) -> DAG | None:
526 """Implementing Operator."""
527 return self.dag
529 @property
530 def output(self) -> XComArg:
531 """Returns reference to XCom pushed by current operator"""
532 from airflow.models.xcom_arg import XComArg
534 return XComArg(operator=self)
536 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
537 """Implementing DAGNode."""
538 return DagAttributeTypes.OP, self.task_id
540 def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
541 """Get the kwargs to create the unmapped operator.
543 This exists because taskflow operators expand against op_kwargs, not the
544 entire operator kwargs dict.
545 """
546 return self._get_specified_expand_input().resolve(context, session)
548 def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
549 """Get init kwargs to unmap the underlying operator class.
551 :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
552 """
553 if strict:
554 prevent_duplicates(
555 self.partial_kwargs,
556 mapped_kwargs,
557 fail_reason="unmappable or already specified",
558 )
560 # If params appears in the mapped kwargs, we need to merge it into the
561 # partial params, overriding existing keys.
562 params = copy.copy(self.params)
563 with contextlib.suppress(KeyError):
564 params.update(mapped_kwargs["params"])
566 # Ordering is significant; mapped kwargs should override partial ones,
567 # and the specially handled params should be respected.
568 return {
569 "task_id": self.task_id,
570 "dag": self.dag,
571 "task_group": self.task_group,
572 "start_date": self.start_date,
573 "end_date": self.end_date,
574 **self.partial_kwargs,
575 **mapped_kwargs,
576 "params": params,
577 }
579 def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
580 """Get the "normal" Operator after applying the current mapping.
582 The *resolve* argument is only used if ``operator_class`` is a real
583 class, i.e. if this operator is not serialized. If ``operator_class`` is
584 not a class (i.e. this DAG has been deserialized), this returns a
585 SerializedBaseOperator that "looks like" the actual unmapping result.
587 If *resolve* is a two-tuple (context, session), the information is used
588 to resolve the mapped arguments into init arguments. If it is a mapping,
589 no resolving happens, the mapping directly provides those init arguments
590 resolved from mapped kwargs.
592 :meta private:
593 """
594 if isinstance(self.operator_class, type):
595 if isinstance(resolve, collections.abc.Mapping):
596 kwargs = resolve
597 elif resolve is not None:
598 kwargs, _ = self._expand_mapped_kwargs(*resolve)
599 else:
600 raise RuntimeError("cannot unmap a non-serialized operator without context")
601 kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
602 op = self.operator_class(**kwargs, _airflow_from_mapped=True)
603 # We need to overwrite task_id here because BaseOperator further
604 # mangles the task_id based on the task hierarchy (namely, group_id
605 # is prepended, and '__N' appended to deduplicate). This is hacky,
606 # but better than duplicating the whole mangling logic.
607 op.task_id = self.task_id
608 return op
610 # After a mapped operator is serialized, there's no real way to actually
611 # unmap it since we've lost access to the underlying operator class.
612 # This tries its best to simply "forward" all the attributes on this
613 # mapped operator to a new SerializedBaseOperator instance.
614 from airflow.serialization.serialized_objects import SerializedBaseOperator
616 op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
617 SerializedBaseOperator.populate_operator(op, self.operator_class)
618 return op
620 def _get_specified_expand_input(self) -> ExpandInput:
621 """Input received from the expand call on the operator."""
622 return getattr(self, self._expand_input_attr)
624 def prepare_for_execution(self) -> MappedOperator:
625 # Since a mapped operator cannot be used for execution, and an unmapped
626 # BaseOperator needs to be created later (see render_template_fields),
627 # we don't need to create a copy of the MappedOperator here.
628 return self
630 def iter_mapped_dependencies(self) -> Iterator[Operator]:
631 """Upstream dependencies that provide XComs used by this task for task mapping."""
632 from airflow.models.xcom_arg import XComArg
634 for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
635 yield operator
637 @cache
638 def get_parse_time_mapped_ti_count(self) -> int:
639 current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
640 try:
641 parent_count = super().get_parse_time_mapped_ti_count()
642 except NotMapped:
643 return current_count
644 return parent_count * current_count
646 def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
647 current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
648 try:
649 parent_count = super().get_mapped_ti_count(run_id, session=session)
650 except NotMapped:
651 return current_count
652 return parent_count * current_count
654 def render_template_fields(
655 self,
656 context: Context,
657 jinja_env: jinja2.Environment | None = None,
658 ) -> None:
659 """Template all attributes listed in *self.template_fields*.
661 This updates *context* to reference the map-expanded task and relevant
662 information, without modifying the mapped operator. The expanded task
663 in *context* is then rendered in-place.
665 :param context: Context dict with values to apply on content.
666 :param jinja_env: Jinja environment to use for rendering.
667 """
668 if not jinja_env:
669 jinja_env = self.get_template_env()
671 # Ideally we'd like to pass in session as an argument to this function,
672 # but we can't easily change this function signature since operators
673 # could override this. We can't use @provide_session since it closes and
674 # expunges everything, which we don't want to do when we are so "deep"
675 # in the weeds here. We don't close this session for the same reason.
676 session = settings.Session()
678 mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
679 unmapped_task = self.unmap(mapped_kwargs)
680 context_update_for_unmapped(context, unmapped_task)
682 self._do_render_template_fields(
683 parent=unmapped_task,
684 template_fields=self.template_fields,
685 context=context,
686 jinja_env=jinja_env,
687 seen_oids=seen_oids,
688 session=session,
689 )