Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/mappedoperator.py: 48%

377 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections 

21import collections.abc 

22import contextlib 

23import copy 

24import datetime 

25import warnings 

26from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union 

27 

28import attr 

29import pendulum 

30from sqlalchemy.orm.session import Session 

31 

32from airflow import settings 

33from airflow.compat.functools import cache 

34from airflow.exceptions import AirflowException, UnmappableOperator 

35from airflow.models.abstractoperator import ( 

36 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

37 DEFAULT_OWNER, 

38 DEFAULT_POOL_SLOTS, 

39 DEFAULT_PRIORITY_WEIGHT, 

40 DEFAULT_QUEUE, 

41 DEFAULT_RETRIES, 

42 DEFAULT_RETRY_DELAY, 

43 DEFAULT_TRIGGER_RULE, 

44 DEFAULT_WEIGHT_RULE, 

45 AbstractOperator, 

46 NotMapped, 

47 TaskStateChangeCallback, 

48) 

49from airflow.models.expandinput import ( 

50 DictOfListsExpandInput, 

51 ExpandInput, 

52 ListOfDictsExpandInput, 

53 OperatorExpandArgument, 

54 OperatorExpandKwargsArgument, 

55 is_mappable, 

56) 

57from airflow.models.param import ParamsDict 

58from airflow.models.pool import Pool 

59from airflow.serialization.enums import DagAttributeTypes 

60from airflow.ti_deps.deps.base_ti_dep import BaseTIDep 

61from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded 

62from airflow.typing_compat import Literal 

63from airflow.utils.context import Context, context_update_for_unmapped 

64from airflow.utils.helpers import is_container, prevent_duplicates 

65from airflow.utils.operator_resources import Resources 

66from airflow.utils.trigger_rule import TriggerRule 

67from airflow.utils.types import NOTSET 

68 

69if TYPE_CHECKING: 

70 import jinja2 # Slow import. 

71 

72 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink 

73 from airflow.models.dag import DAG 

74 from airflow.models.operator import Operator 

75 from airflow.models.xcom_arg import XComArg 

76 from airflow.utils.task_group import TaskGroup 

77 

78ValidationSource = Union[Literal["expand"], Literal["partial"]] 

79 

80 

def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Validate keyword arguments passed to ``partial()`` or ``expand()``.

    Walks the operator class's MRO, popping every argument accepted by some
    ``__init__`` out of *value*; anything left over is unknown and rejected.

    :param op: The operator class being mapped.
    :param func: Which call site is validated (``"expand"`` or ``"partial"``).
    :param value: The keyword arguments supplied by the user.
    :raises ValueError: If ``expand()`` received a value of an unmappable type.
    :raises TypeError: If an argument name is not accepted by any ``__init__``
        in the operator's MRO.
    """
    # use a dict so order of args is same as code order
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            # Use a distinct local instead of rebinding the *value* parameter
            # (the original shadowed it, which read ambiguously).
            arg_value = unknown_args.pop(name, NOTSET)
            if func != "expand":
                continue
            if arg_value is NOTSET:
                continue
            if is_mappable(arg_value):
                continue
            type_name = type(arg_value).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")

110 

111 

def ensure_xcomarg_return_value(arg: Any) -> None:
    """Recursively check that every XComArg in *arg* maps over the default return key.

    :raises ValueError: If any referenced XCom uses a custom (non-default) key.
    """
    from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg

    if isinstance(arg, XComArg):
        for operator, key in arg.iter_references():
            if key != XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not is_container(arg):
        return
    # Recurse into container values; mappings contribute their values,
    # other iterables contribute their items.
    if isinstance(arg, collections.abc.Mapping):
        children: Iterable[Any] = arg.values()
    elif isinstance(arg, collections.abc.Iterable):
        children = arg
    else:
        return
    for child in children:
        ensure_xcomarg_return_value(child)

127 

128 

@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the DAG.
    """

    # The concrete operator subclass that will eventually be mapped.
    operator_class: type[BaseOperator]
    # Keyword arguments captured by partial(); shared by every mapped instance.
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        # Reject partial() kwargs that no __init__ in the operator's MRO accepts.
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # A partial garbage-collected without ever being expanded almost
        # certainly means the task never made it into the DAG; warn to help
        # the user notice.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!")

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Create a ``MappedOperator`` expanded against the given keyword arguments.

        :raises TypeError: If no arguments are given, or an argument has an
            unmappable type or duplicates a ``partial()`` argument.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Create a ``MappedOperator`` expanded against whole kwargs dicts.

        :param kwargs: A sequence of mappings/XComArgs, or a single XComArg.
        :param strict: Whether duplicate keys against ``partial()`` should fail
            at execution time (passed through as ``disallow_kwargs_override``).
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        # Common implementation behind expand() and expand_kwargs(): build a
        # MappedOperator from this partial's captured state.
        from airflow.operators.empty import EmptyOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        # These become explicit MappedOperator attributes rather than staying
        # in partial_kwargs.
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
        )
        return op

228 

229 

230@attr.define( 

231 kw_only=True, 

232 # Disable custom __getstate__ and __setstate__ generation since it interacts 

233 # badly with Airflow's DAG serialization and pickling. When a mapped task is 

234 # deserialized, subclasses are coerced into MappedOperator, but when it goes 

235 # through DAG pickling, all attributes defined in the subclasses are dropped 

236 # by attrs's custom state management. Since attrs does not do anything too 

237 # special here (the logic is only important for slots=True), we use Python's 

238 # built-in implementation, which works (as proven by good old BaseOperator). 

239 getstate_setstate=False, 

240) 

241class MappedOperator(AbstractOperator): 

242 """Object representing a mapped operator in a DAG.""" 

243 

244 # This attribute serves double purpose. For a "normal" operator instance 

245 # loaded from DAG, this holds the underlying non-mapped operator class that 

246 # can be used to create an unmapped operator for execution. For an operator 

247 # recreated from a serialized DAG, however, this holds the serialized data 

248 # that can be used to unmap this into a SerializedBaseOperator. 

249 operator_class: type[BaseOperator] | dict[str, Any] 

250 

251 expand_input: ExpandInput 

252 partial_kwargs: dict[str, Any] 

253 

254 # Needed for serialization. 

255 task_id: str 

256 params: ParamsDict | dict 

257 deps: frozenset[BaseTIDep] 

258 operator_extra_links: Collection[BaseOperatorLink] 

259 template_ext: Sequence[str] 

260 template_fields: Collection[str] 

261 template_fields_renderers: dict[str, str] 

262 ui_color: str 

263 ui_fgcolor: str 

264 _is_empty: bool 

265 _task_module: str 

266 _task_type: str 

267 _operator_name: str 

268 

269 dag: DAG | None 

270 task_group: TaskGroup | None 

271 start_date: pendulum.DateTime | None 

272 end_date: pendulum.DateTime | None 

273 upstream_task_ids: set[str] = attr.ib(factory=set, init=False) 

274 downstream_task_ids: set[str] = attr.ib(factory=set, init=False) 

275 

276 _disallow_kwargs_override: bool 

277 """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. 

278 

279 If *False*, values from ``expand_input`` under duplicate keys override those 

280 under corresponding keys in ``partial_kwargs``. 

281 """ 

282 

283 _expand_input_attr: str 

284 """Where to get kwargs to calculate expansion length against. 

285 

286 This should be a name to call ``getattr()`` on. 

287 """ 

288 

289 subdag: None = None # Since we don't support SubDagOperator, this is always None. 

290 

291 HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( 

292 ( 

293 "parse_time_mapped_ti_count", 

294 "operator_class", 

295 ) 

296 ) 

297 

    def __hash__(self):
        # Hash by object identity.
        return id(self)

    def __repr__(self):
        return f"<Mapped({self._task_type}): {self.task_id}>"

303 

    def __attrs_post_init__(self):
        from airflow.models.xcom_arg import XComArg

        # Expanding a task that itself lives inside an expanded task group is
        # not implemented.
        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        # Register this node with its task group and DAG.
        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        # Wire upstream dependencies for XComArgs referenced by the expand
        # input, and for any partial kwarg that is a templated field.
        XComArg.apply_upstream_relationship(self, self.expand_input.value)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)
        if self.partial_kwargs.get("sla") is not None:
            raise AirflowException(
                f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
                f"{self.task_id!r}."
            )

323 

    @classmethod
    @cache
    def get_serialized_fields(cls):
        """Return the set of attribute names to serialize for a mapped operator."""
        # Not using 'cls' here since we only want to serialize base fields.
        return frozenset(attr.fields_dict(MappedOperator)) - {
            "dag",
            "deps",
            "expand_input",  # This is needed to be able to accept XComArg.
            "subdag",
            "task_group",
            "upstream_task_ids",
        }

336 

    @staticmethod
    @cache
    def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
        """Return the operator class's deps plus the mapped-expansion dep.

        ``deps`` must be a class-level set; this keeps the result valid to
        cache per operator class.
        """
        operator_deps = operator_class.deps
        if not isinstance(operator_deps, collections.abc.Set):
            raise UnmappableOperator(
                f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
                f"not a {type(operator_deps).__name__}"
            )
        return operator_deps | {MappedTaskIsExpanded()}

347 

    @property
    def task_type(self) -> str:
        """Implementing Operator."""
        return self._task_type

    @property
    def operator_name(self) -> str:
        # Display name; set from custom_operator_name when the operator class
        # declares one (see OperatorPartial._expand), else the class name.
        return self._operator_name

    @property
    def inherits_from_empty_operator(self) -> bool:
        """Implementing Operator."""
        return self._is_empty

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

371 

    # The properties below proxy ``partial_kwargs`` so a mapped task exposes
    # the same scheduling attributes as a regular operator, falling back to
    # the shared defaults imported from ``airflow.models.abstractoperator``.

    @property
    def owner(self) -> str:  # type: ignore[override]
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @property
    def retries(self) -> int | None:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)

    @property
    def pool_slots(self) -> int | None:
        # NOTE(review): annotation changed from ``str | None`` — pool slot
        # counts are integers; confirm DEFAULT_POOL_SLOTS is an int.
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @property
    def retry_exponential_backoff(self) -> bool:
        return bool(self.partial_kwargs.get("retry_exponential_backoff"))

    @property
    def priority_weight(self) -> int:  # type: ignore[override]
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @property
    def weight_rule(self) -> str:  # type: ignore[override]
        # NOTE(review): annotation changed from ``int`` — weight rules are
        # string constants; confirm DEFAULT_WEIGHT_RULE's type.
        return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)

    @property
    def sla(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("sla")

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

448 

    # Callback and lineage properties also proxy ``partial_kwargs``. The
    # setters write straight back into ``partial_kwargs`` so assigned values
    # are included when the operator is unmapped (see ``_get_unmap_kwargs``).

    @property
    def on_execute_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
        return self.partial_kwargs.get("on_execute_callback")

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None:
        self.partial_kwargs["on_execute_callback"] = value

    @property
    def on_failure_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
        return self.partial_kwargs.get("on_failure_callback")

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None:
        self.partial_kwargs["on_failure_callback"] = value

    @property
    def on_retry_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
        return self.partial_kwargs.get("on_retry_callback")

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
        self.partial_kwargs["on_retry_callback"] = value

    @property
    def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
        return self.partial_kwargs.get("on_success_callback")

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
        self.partial_kwargs["on_success_callback"] = value

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property  # type: ignore[override]
    def inlets(self) -> list[Any]:  # type: ignore[override]
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["inlets"] = value

    @property  # type: ignore[override]
    def outlets(self) -> list[Any]:  # type: ignore[override]
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["outlets"] = value

    # Documentation attributes (one per supported markup format).

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

524 

    def get_dag(self) -> DAG | None:
        """Implementing Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return a reference to the XCom pushed by the current operator."""
        from airflow.models.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implementing DAGNode."""
        return DagAttributeTypes.OP, self.task_id

539 

    def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.

        :return: A two-tuple of the resolved kwargs mapping and a set of ids
            (used as ``seen_oids`` during template rendering).
        """
        return self._get_specified_expand_input().resolve(context, session)

547 

548 def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: 

549 """Get init kwargs to unmap the underlying operator class. 

550 

551 :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. 

552 """ 

553 if strict: 

554 prevent_duplicates( 

555 self.partial_kwargs, 

556 mapped_kwargs, 

557 fail_reason="unmappable or already specified", 

558 ) 

559 

560 # If params appears in the mapped kwargs, we need to merge it into the 

561 # partial params, overriding existing keys. 

562 params = copy.copy(self.params) 

563 with contextlib.suppress(KeyError): 

564 params.update(mapped_kwargs["params"]) 

565 

566 # Ordering is significant; mapped kwargs should override partial ones, 

567 # and the specially handled params should be respected. 

568 return { 

569 "task_id": self.task_id, 

570 "dag": self.dag, 

571 "task_group": self.task_group, 

572 "start_date": self.start_date, 

573 "end_date": self.end_date, 

574 **self.partial_kwargs, 

575 **mapped_kwargs, 

576 "params": params, 

577 } 

578 

    def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" Operator after applying the current mapping.

        The *resolve* argument is only used if ``operator_class`` is a real
        class, i.e. if this operator is not serialized. If ``operator_class`` is
        not a class (i.e. this DAG has been deserialized), this returns a
        SerializedBaseOperator that "looks like" the actual unmapping result.

        If *resolve* is a two-tuple (context, session), the information is used
        to resolve the mapped arguments into init arguments. If it is a mapping,
        no resolving happens, the mapping directly provides those init arguments
        resolved from mapped kwargs.

        :meta private:
        """
        if isinstance(self.operator_class, type):
            # Real class: instantiate the actual operator from the merged
            # partial and mapped kwargs.
            if isinstance(resolve, collections.abc.Mapping):
                kwargs = resolve
            elif resolve is not None:
                kwargs, _ = self._expand_mapped_kwargs(*resolve)
            else:
                raise RuntimeError("cannot unmap a non-serialized operator without context")
            kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
            op = self.operator_class(**kwargs, _airflow_from_mapped=True)
            # We need to overwrite task_id here because BaseOperator further
            # mangles the task_id based on the task hierarchy (namely, group_id
            # is prepended, and '__N' appended to deduplicate). This is hacky,
            # but better than duplicating the whole mangling logic.
            op.task_id = self.task_id
            return op

        # After a mapped operator is serialized, there's no real way to actually
        # unmap it since we've lost access to the underlying operator class.
        # This tries its best to simply "forward" all the attributes on this
        # mapped operator to a new SerializedBaseOperator instance.
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
        SerializedBaseOperator.populate_operator(op, self.operator_class)
        return op

619 

    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        # _expand_input_attr holds the attribute name chosen at construction
        # time (e.g. "expand_input" for classic operators).
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.models.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator

636 

    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of mapped task instances resolvable statically at parse time.

        Multiplies this task's own expansion length by the parent's count; if
        the parent is not mapped (``NotMapped``), returns just this task's
        count.

        NOTE(review): ``@cache`` on an instance method keys on ``self`` and
        keeps each instance alive for the cache's lifetime — presumably
        acceptable for DAG-parsing objects, but worth confirming.
        """
        current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
        try:
            parent_count = super().get_parse_time_mapped_ti_count()
        except NotMapped:
            return current_count
        return parent_count * current_count

645 

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Total number of mapped task instances for *run_id*.

        Like ``get_parse_time_mapped_ti_count`` but resolves actual map
        lengths through the given session, multiplied by the parent's count
        when the parent is itself mapped.
        """
        current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
        try:
            parent_count = super().get_mapped_ti_count(run_id, session=session)
        except NotMapped:
            return current_count
        return parent_count * current_count

653 

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        if not jinja_env:
            jinja_env = self.get_template_env()

        # Ideally we'd like to pass in session as an argument to this function,
        # but we can't easily change this function signature since operators
        # could override this. We can't use @provide_session since it closes and
        # expunges everything, which we don't want to do when we are so "deep"
        # in the weeds here. We don't close this session for the same reason.
        session = settings.Session()

        # Resolve the mapped arguments, build the real (unmapped) operator
        # from them, and expose it through the context for rendering.
        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        self._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
            session=session,
        )