Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/mappedoperator.py: 49%

389 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections 

21import collections.abc 

22import contextlib 

23import copy 

24import datetime 

25import warnings 

26from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union 

27 

28import attr 

29import pendulum 

30from sqlalchemy.orm.session import Session 

31 

32from airflow import settings 

33from airflow.compat.functools import cache 

34from airflow.exceptions import AirflowException, UnmappableOperator 

35from airflow.models.abstractoperator import ( 

36 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

37 DEFAULT_OWNER, 

38 DEFAULT_POOL_SLOTS, 

39 DEFAULT_PRIORITY_WEIGHT, 

40 DEFAULT_QUEUE, 

41 DEFAULT_RETRIES, 

42 DEFAULT_RETRY_DELAY, 

43 DEFAULT_TRIGGER_RULE, 

44 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

45 DEFAULT_WEIGHT_RULE, 

46 AbstractOperator, 

47 NotMapped, 

48 TaskStateChangeCallback, 

49) 

50from airflow.models.expandinput import ( 

51 DictOfListsExpandInput, 

52 ExpandInput, 

53 ListOfDictsExpandInput, 

54 OperatorExpandArgument, 

55 OperatorExpandKwargsArgument, 

56 is_mappable, 

57) 

58from airflow.models.param import ParamsDict 

59from airflow.models.pool import Pool 

60from airflow.serialization.enums import DagAttributeTypes 

61from airflow.ti_deps.deps.base_ti_dep import BaseTIDep 

62from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded 

63from airflow.typing_compat import Literal 

64from airflow.utils.context import Context, context_update_for_unmapped 

65from airflow.utils.helpers import is_container, prevent_duplicates 

66from airflow.utils.operator_resources import Resources 

67from airflow.utils.trigger_rule import TriggerRule 

68from airflow.utils.types import NOTSET 

69from airflow.utils.xcom import XCOM_RETURN_KEY 

70 

71if TYPE_CHECKING: 

72 import jinja2 # Slow import. 

73 

74 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink 

75 from airflow.models.dag import DAG 

76 from airflow.models.operator import Operator 

77 from airflow.models.xcom_arg import XComArg 

78 from airflow.utils.task_group import TaskGroup 

79 

80ValidationSource = Union[Literal["expand"], Literal["partial"]] 

81 

82 

def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Validate kwargs passed to ``partial()`` or ``expand()`` against *op*'s signature.

    Walks the operator class's MRO and consumes from *value* every name that some
    ``__init__`` declares (as recorded by ``BaseOperatorMeta``). For ``expand()``
    calls, each consumed value must additionally be mappable.

    :param op: The operator class being partial'd/expanded.
    :param func: Which API is being validated ("partial" or "expand").
    :param value: The keyword arguments the user supplied.
    :raises ValueError: When an ``expand()`` argument is not a mappable type.
    :raises TypeError: When a keyword matches no parameter anywhere in the MRO.
    """
    # Work on a dict copy so unknown-argument reporting keeps the caller's code order.
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            # This class in the MRO was not created by BaseOperatorMeta; it
            # declares no parameters we need to account for.
            continue
        for name in param_names:
            arg_value = unknown_args.pop(name, NOTSET)
            # Mappability only matters for expand(); partial() accepts anything.
            if func != "expand" or arg_value is NOTSET or is_mappable(arg_value):
                continue
            type_name = type(arg_value).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # Nothing left to account for; no need to walk further up the MRO.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")

112 

113 

def ensure_xcomarg_return_value(arg: Any) -> None:
    """Recursively reject XComArg references that use a non-default key.

    Task mapping can only resolve an upstream task's standard return value, so
    every ``XComArg`` found in *arg* — directly or nested inside containers —
    must reference ``XCOM_RETURN_KEY``.

    :raises ValueError: If an ``XComArg`` with a custom key is found.
    """
    from airflow.models.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        for operator, key in arg.iter_references():
            if key != XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not is_container(arg):
        return  # A scalar leaf; nothing more to inspect.
    if isinstance(arg, collections.abc.Mapping):
        for item in arg.values():
            ensure_xcomarg_return_value(item)
    elif isinstance(arg, collections.abc.Iterable):
        for item in arg:
            ensure_xcomarg_return_value(item)

129 

130 

@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the DAG.
    """

    operator_class: type[BaseOperator]
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # A partial that is never expanded silently never makes it into the
        # DAG, so warn the user about the likely mistake.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!")

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Expand against per-keyword lists/XComArgs, one task instance per element.

        :raises TypeError: If called with no keyword arguments, or with kwargs
            that match no parameter of the operator class.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Expand against a sequence of whole-kwargs mappings (or an XComArg of them).

        :param kwargs: A list whose elements are dicts/XComArgs, or a single
            XComArg resolving to such a list.
        :param strict: Whether duplicate keys between partial and expanded
            kwargs should fail at execution time.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    # Report the offending element's type; naming the outer
                    # sequence's type here would be misleading.
                    raise TypeError(f"expected XComArg or list[dict], not {type(item).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the MappedOperator from validated expansion input.

        Shared tail of ``expand()`` and ``expand_kwargs()``; splits the static
        partial kwargs from the structural ones (task_id, dag, task_group,
        dates) and hands everything to the MappedOperator constructor.
        """
        from airflow.operators.empty import EmptyOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        # These are MappedOperator attributes in their own right, not part of
        # the per-task kwargs the unmapped operator is eventually built from.
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
        )
        return op

230 

231 

@attr.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's DAG serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through DAG pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a DAG."""

    # This attribute serves double purpose. For a "normal" operator instance
    # loaded from DAG, this holds the underlying non-mapped operator class that
    # can be used to create an unmapped operator for execution. For an operator
    # recreated from a serialized DAG, however, this holds the serialized data
    # that can be used to unmap this into a SerializedBaseOperator.
    operator_class: type[BaseOperator] | dict[str, Any]

    # What the user passed to expand()/expand_kwargs(), and the static kwargs
    # from partial() that apply to every expanded task instance.
    expand_input: ExpandInput
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    deps: frozenset[BaseTIDep]
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool  # Whether the underlying class is an EmptyOperator.
    _task_module: str
    _task_type: str
    _operator_name: str

    # Structural attributes mirroring BaseOperator; populated by
    # OperatorPartial._expand() rather than taken from partial_kwargs.
    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
    downstream_task_ids: set[str] = attr.ib(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    subdag: None = None  # Since we don't support SubDagOperator, this is always None.
    supports_lineage: bool = False
    is_setup: bool = False
    is_teardown: bool = False
    on_failure_fail_dagrun: bool = False

    # Attributes the web UI must not try to render for a mapped task, on top
    # of those already hidden by AbstractOperator.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        (
            "parse_time_mapped_ti_count",
            "operator_class",
        )
    )

303 

    def __hash__(self):
        # Identity-based hash: each mapped task node is a unique object in the
        # DAG, and attrs instances here are mutable (so value hashing is out).
        return id(self)

    def __repr__(self):
        return f"<Mapped({self._task_type}): {self.task_id}>"

309 

    def __attrs_post_init__(self):
        """Wire this node into its DAG/task group and validate parse-time constraints."""
        from airflow.models.xcom_arg import XComArg

        # Expansion inside an already-expanded task group is unsupported; fail
        # early at parse time.
        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        # Register with the task group and DAG, mirroring what BaseOperator
        # does for regular operators.
        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        # Establish upstream relationships for every XComArg referenced by the
        # expansion input, and by any templated partial kwargs.
        XComArg.apply_upstream_relationship(self, self.expand_input.value)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)
        if self.partial_kwargs.get("sla") is not None:
            raise AirflowException(
                f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
                f"{self.task_id!r}."
            )

329 

330 @classmethod 

331 @cache 

332 def get_serialized_fields(cls): 

333 # Not using 'cls' here since we only want to serialize base fields. 

334 return frozenset(attr.fields_dict(MappedOperator)) - { 

335 "dag", 

336 "deps", 

337 "expand_input", # This is needed to be able to accept XComArg. 

338 "subdag", 

339 "task_group", 

340 "upstream_task_ids", 

341 "supports_lineage", 

342 "is_setup", 

343 "is_teardown", 

344 "on_failure_fail_dagrun", 

345 } 

346 

347 @staticmethod 

348 @cache 

349 def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]: 

350 operator_deps = operator_class.deps 

351 if not isinstance(operator_deps, collections.abc.Set): 

352 raise UnmappableOperator( 

353 f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, " 

354 f"not a {type(operator_deps).__name__}" 

355 ) 

356 return operator_deps | {MappedTaskIsExpanded()} 

357 

    @property
    def task_type(self) -> str:
        """Implementing Operator."""
        return self._task_type

    @property
    def operator_name(self) -> str:
        """Display name of the underlying operator class."""
        return self._operator_name

    @property
    def inherits_from_empty_operator(self) -> bool:
        """Implementing Operator."""
        return self._is_empty

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    # The properties below expose values the user gave to partial(); each one
    # falls back to the same default a plain BaseOperator would use.

    @property
    def owner(self) -> str:  # type: ignore[override]
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

402 

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @property
    def retries(self) -> int | None:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)

    @property
    def pool_slots(self) -> int | None:
        # Annotation corrected from ``str | None``: this is a slot count with
        # an integer default (DEFAULT_POOL_SLOTS).
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @property
    def retry_exponential_backoff(self) -> bool:
        return bool(self.partial_kwargs.get("retry_exponential_backoff"))

    @property
    def priority_weight(self) -> int:  # type: ignore[override]
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @property
    def weight_rule(self) -> int:  # type: ignore[override]
        # NOTE(review): DEFAULT_WEIGHT_RULE is a weight-rule *name*; the ``int``
        # annotation looks inherited by mistake — confirm against base class.
        return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)

    @property
    def sla(self) -> datetime.timedelta | None:
        # Always None in practice: __attrs_post_init__ rejects a non-None sla.
        return self.partial_kwargs.get("sla")

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

469 

    # Lifecycle callbacks live in partial_kwargs so a single value applies to
    # every expanded task instance; the setters write back to the same place.

    @property
    def on_execute_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
        return self.partial_kwargs.get("on_execute_callback")

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None:
        self.partial_kwargs["on_execute_callback"] = value

    @property
    def on_failure_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
        return self.partial_kwargs.get("on_failure_callback")

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None:
        self.partial_kwargs["on_failure_callback"] = value

    @property
    def on_retry_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
        return self.partial_kwargs.get("on_retry_callback")

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
        self.partial_kwargs["on_retry_callback"] = value

    @property
    def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
        return self.partial_kwargs.get("on_success_callback")

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
        self.partial_kwargs["on_success_callback"] = value

501 

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property  # type: ignore[override]
    def inlets(self) -> list[Any]:  # type: ignore[override]
        # Lineage inlets; mutable via the setter below (written back into
        # partial_kwargs so all expanded instances see the same value).
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["inlets"] = value

    @property  # type: ignore[override]
    def outlets(self) -> list[Any]:  # type: ignore[override]
        # Lineage outlets; see inlets.
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["outlets"] = value

    # Per-format task documentation, as passed to partial().

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

545 

    def get_dag(self) -> DAG | None:
        """Implementing Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Returns reference to XCom pushed by current operator."""
        from airflow.models.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implementing DAGNode."""
        # A mapped operator serializes inside a task group as a plain operator
        # reference: just its task_id.
        return DagAttributeTypes.OP, self.task_id

560 

    def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.

        :return: The resolved kwargs mapping plus a set of object ids already
            handled during resolution (forwarded to template rendering as
            ``seen_oids`` so those values are not rendered twice).
        """
        return self._get_specified_expand_input().resolve(context, session)

568 

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """Get init kwargs to unmap the underlying operator class.

        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
        :param strict: If True, duplicate keys between partial and mapped
            kwargs raise instead of being silently overridden.
        """
        if strict:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # If params appears in the mapped kwargs, we need to merge it into the
        # partial params, overriding existing keys.
        params = copy.copy(self.params)
        with contextlib.suppress(KeyError):
            params.update(mapped_kwargs["params"])

        # Ordering is significant; mapped kwargs should override partial ones,
        # and the specially handled params should be respected.
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **mapped_kwargs,
            "params": params,
        }

599 

    def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" Operator after applying the current mapping.

        The *resolve* argument is only used if ``operator_class`` is a real
        class, i.e. if this operator is not serialized. If ``operator_class`` is
        not a class (i.e. this DAG has been deserialized), this returns a
        SerializedBaseOperator that "looks like" the actual unmapping result.

        If *resolve* is a two-tuple (context, session), the information is used
        to resolve the mapped arguments into init arguments. If it is a mapping,
        no resolving happens, the mapping directly provides those init arguments
        resolved from mapped kwargs.

        :raises RuntimeError: If ``operator_class`` is a real class but
            *resolve* is None — there is nothing to build the kwargs from.

        :meta private:
        """
        if isinstance(self.operator_class, type):
            if isinstance(resolve, collections.abc.Mapping):
                kwargs = resolve
            elif resolve is not None:
                kwargs, _ = self._expand_mapped_kwargs(*resolve)
            else:
                raise RuntimeError("cannot unmap a non-serialized operator without context")
            kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
            op = self.operator_class(**kwargs, _airflow_from_mapped=True)
            # We need to overwrite task_id here because BaseOperator further
            # mangles the task_id based on the task hierarchy (namely, group_id
            # is prepended, and '__N' appended to deduplicate). This is hacky,
            # but better than duplicating the whole mangling logic.
            op.task_id = self.task_id
            return op

        # After a mapped operator is serialized, there's no real way to actually
        # unmap it since we've lost access to the underlying operator class.
        # This tries its best to simply "forward" all the attributes on this
        # mapped operator to a new SerializedBaseOperator instance.
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
        SerializedBaseOperator.populate_operator(op, self.operator_class)
        return op

640 

    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator.

        Indirect via ``_expand_input_attr`` because taskflow-mapped operators
        store their expansion input under a different attribute name.
        """
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

650 

651 def iter_mapped_dependencies(self) -> Iterator[Operator]: 

652 """Upstream dependencies that provide XComs used by this task for task mapping.""" 

653 from airflow.models.xcom_arg import XComArg 

654 

655 for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): 

656 yield operator 

657 

    # NOTE(review): ``@cache`` on an instance method keys on ``self`` and keeps
    # the operator alive for the cache's lifetime (ruff B019); presumably fine
    # since mapped operators live as long as the parsed DAG — confirm.
    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Number of mapped task instances computable at parse time.

        Multiplies this task's own expansion length by the enclosing mapped
        task group's count; if there is no mapped parent (``NotMapped``), this
        task's own count is the total.
        """
        current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
        try:
            parent_count = super().get_parse_time_mapped_ti_count()
        except NotMapped:
            return current_count
        return parent_count * current_count

666 

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Number of mapped task instances for *run_id*, resolved from the DB.

        Same multiplication scheme as ``get_parse_time_mapped_ti_count`` but
        using actual run-time map lengths.
        """
        current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
        try:
            parent_count = super().get_mapped_ti_count(run_id, session=session)
        except NotMapped:
            return current_count
        return parent_count * current_count

674 

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        if not jinja_env:
            jinja_env = self.get_template_env()

        # Ideally we'd like to pass in session as an argument to this function,
        # but we can't easily change this function signature since operators
        # could override this. We can't use @provide_session since it closes and
        # expunges everything, which we don't want to do when we are so "deep"
        # in the weeds here. We don't close this session for the same reason.
        session = settings.Session()

        # Resolve the mapped arguments for this map index and materialize the
        # real operator; rendering then happens on that unmapped task.
        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
            session=session,
        )