Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/mappedoperator.py: 49%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

491 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import contextlib 

22import copy 

23import warnings 

24from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union 

25 

26import attr 

27import methodtools 

28 

29from airflow.exceptions import AirflowException, UnmappableOperator 

30from airflow.models.abstractoperator import ( 

31 DEFAULT_EXECUTOR, 

32 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

33 DEFAULT_OWNER, 

34 DEFAULT_POOL_SLOTS, 

35 DEFAULT_PRIORITY_WEIGHT, 

36 DEFAULT_QUEUE, 

37 DEFAULT_RETRIES, 

38 DEFAULT_RETRY_DELAY, 

39 DEFAULT_TRIGGER_RULE, 

40 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

41 DEFAULT_WEIGHT_RULE, 

42 AbstractOperator, 

43 NotMapped, 

44) 

45from airflow.models.expandinput import ( 

46 DictOfListsExpandInput, 

47 ListOfDictsExpandInput, 

48 is_mappable, 

49) 

50from airflow.models.pool import Pool 

51from airflow.serialization.enums import DagAttributeTypes 

52from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy 

53from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded 

54from airflow.typing_compat import Literal 

55from airflow.utils.context import context_update_for_unmapped 

56from airflow.utils.helpers import is_container, prevent_duplicates 

57from airflow.utils.task_instance_session import get_current_task_instance_session 

58from airflow.utils.types import NOTSET 

59from airflow.utils.xcom import XCOM_RETURN_KEY 

60 

61if TYPE_CHECKING: 

62 import datetime 

63 from typing import List 

64 

65 import jinja2 # Slow import. 

66 import pendulum 

67 from sqlalchemy.orm.session import Session 

68 

69 from airflow.models.abstractoperator import ( 

70 TaskStateChangeCallback, 

71 ) 

72 from airflow.models.baseoperator import BaseOperator 

73 from airflow.models.baseoperatorlink import BaseOperatorLink 

74 from airflow.models.dag import DAG 

75 from airflow.models.expandinput import ( 

76 ExpandInput, 

77 OperatorExpandArgument, 

78 OperatorExpandKwargsArgument, 

79 ) 

80 from airflow.models.operator import Operator 

81 from airflow.models.param import ParamsDict 

82 from airflow.models.xcom_arg import XComArg 

83 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep 

84 from airflow.triggers.base import BaseTrigger 

85 from airflow.utils.context import Context 

86 from airflow.utils.operator_resources import Resources 

87 from airflow.utils.task_group import TaskGroup 

88 from airflow.utils.trigger_rule import TriggerRule 

89 

90 TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, List[TaskStateChangeCallback]] 

91 

92ValidationSource = Union[Literal["expand"], Literal["partial"]] 

93 

94 

def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Validate kwargs passed to ``partial()`` or ``expand()`` against *op*.

    Walks *op*'s MRO and consumes every recognized ``__init__`` parameter from
    a working copy of *value*. When validating for ``expand()``, each consumed
    argument must additionally be a mappable collection. Anything left
    unconsumed after the walk is an unknown keyword argument.

    :param op: The operator class being partial-ed or expanded.
    :param func: Which call site is being validated ("partial" or "expand").
    :param value: The keyword arguments supplied by the user.
    :raises ValueError: If an ``expand()`` argument is not mappable.
    :raises TypeError: If unknown keyword arguments remain.
    """
    # A dict (not a set) keeps insertion order, so errors report args in code order.
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue  # Not a BaseOperatorMeta-instrumented class; nothing to consume.
        for name in param_names:
            candidate = unknown_args.pop(name, NOTSET)
            # Mappability only matters for expand(); absent args are fine too.
            if func != "expand" or candidate is NOTSET or is_mappable(candidate):
                continue
            type_name = type(candidate).__name__
            raise ValueError(
                f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            )
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")

124 

125 

def ensure_xcomarg_return_value(arg: Any) -> None:
    """Recursively check that any mapped-over XComArg uses only the default key.

    Task mapping is only supported over an upstream task's return value (the
    ``return_value`` XCom); a reference with any other key raises ``ValueError``.
    Containers are walked recursively; non-container scalars are ignored.
    """
    from airflow.models.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        for operator, key in arg.iter_references():
            if key != XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not is_container(arg):
        return  # Plain scalar; nothing to inspect.
    if isinstance(arg, collections.abc.Mapping):
        for item in arg.values():
            ensure_xcomarg_return_value(item)
    elif isinstance(arg, collections.abc.Iterable):
        for item in arg:
            ensure_xcomarg_return_value(item)

141 

142 

@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the DAG.
    """

    # The concrete operator class to be mapped.
    operator_class: type[BaseOperator]
    # Keyword arguments captured by partial(); applied to every mapped instance.
    kwargs: dict[str, Any]
    # DAG-level params merged in at partial() time.
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        # SubDagOperator cannot be mapped; reject it eagerly at parse time.
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # A partial that is garbage-collected without ever being expanded is
        # almost certainly a user mistake; warn so the task is not silently lost.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Map this operator over the given per-argument collections.

        :param mapped_kwargs: One collection (or XComArg) per operator argument.
        :raises TypeError: If called with no arguments.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Map this operator over a list of kwargs dicts (or an XComArg of them).

        :param kwargs: A sequence of mappings/XComArgs, or a single XComArg.
        :param strict: Whether duplicate keys against partial kwargs fail at runtime.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    # NOTE(review): this reports type(kwargs) (e.g. "list"), not
                    # type(item) — the message can be misleading for a bad item;
                    # confirm intent before changing the user-facing text.
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        # Common tail of expand()/expand_kwargs(): build the MappedOperator node.
        from airflow.operators.empty import EmptyOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        # These are first-class MappedOperator attributes, not partial kwargs.
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            start_trigger=self.operator_class.start_trigger,
            next_method=self.operator_class.next_method,
        )
        return op

244 

245 

@attr.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's DAG serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through DAG pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a DAG."""

    # This attribute serves double purpose. For a "normal" operator instance
    # loaded from DAG, this holds the underlying non-mapped operator class that
    # can be used to create an unmapped operator for execution. For an operator
    # recreated from a serialized DAG, however, this holds the serialized data
    # that can be used to unmap this into a SerializedBaseOperator.
    operator_class: type[BaseOperator] | dict[str, Any]

    # What the task maps over (from expand()/expand_kwargs()).
    expand_input: ExpandInput
    # Arguments shared by every expanded instance (from partial()).
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    deps: frozenset[BaseTIDep]
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool
    _task_module: str
    _task_type: str
    _operator_name: str
    start_trigger: BaseTrigger | None
    next_method: str | None
    # Mapped operators always require expansion before execution.
    _needs_expansion: bool = True

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    # Relationship sets are populated by the dependency machinery, not callers.
    upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
    downstream_task_ids: set[str] = attr.ib(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    subdag: None = None  # Since we don't support SubDagOperator, this is always None.
    supports_lineage: bool = False

    # Attributes hidden from the "Task Instance Details" UI view, in addition
    # to those hidden on AbstractOperator.
    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        (
            "parse_time_mapped_ti_count",
            "operator_class",
            "start_trigger",
            "next_method",
        )
    )

319 

    def __hash__(self):
        # Identity-based hashing: each mapped task node in a DAG is unique, and
        # this class does not define value-based equality.
        return id(self)

    def __repr__(self):
        return f"<Mapped({self._task_type}): {self.task_id}>"

325 

326 def __attrs_post_init__(self): 

327 from airflow.models.xcom_arg import XComArg 

328 

329 if self.get_closest_mapped_task_group() is not None: 

330 raise NotImplementedError("operator expansion in an expanded task group is not yet supported") 

331 

332 if self.task_group: 

333 self.task_group.add(self) 

334 if self.dag: 

335 self.dag.add_task(self) 

336 XComArg.apply_upstream_relationship(self, self.expand_input.value) 

337 for k, v in self.partial_kwargs.items(): 

338 if k in self.template_fields: 

339 XComArg.apply_upstream_relationship(self, v) 

340 if self.partial_kwargs.get("sla") is not None: 

341 raise AirflowException( 

342 f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task " 

343 f"{self.task_id!r}." 

344 ) 

345 

346 @methodtools.lru_cache(maxsize=None) 

347 @classmethod 

348 def get_serialized_fields(cls): 

349 # Not using 'cls' here since we only want to serialize base fields. 

350 return frozenset(attr.fields_dict(MappedOperator)) - { 

351 "dag", 

352 "deps", 

353 "expand_input", # This is needed to be able to accept XComArg. 

354 "subdag", 

355 "task_group", 

356 "upstream_task_ids", 

357 "supports_lineage", 

358 "_is_setup", 

359 "_is_teardown", 

360 "_on_failure_fail_dagrun", 

361 } 

362 

363 @methodtools.lru_cache(maxsize=None) 

364 @staticmethod 

365 def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]: 

366 operator_deps = operator_class.deps 

367 if not isinstance(operator_deps, collections.abc.Set): 

368 raise UnmappableOperator( 

369 f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, " 

370 f"not a {type(operator_deps).__name__}" 

371 ) 

372 return operator_deps | {MappedTaskIsExpanded()} 

373 

    # ------------------------------------------------------------------
    # partial_kwargs-backed attribute surface.
    #
    # Until it is unmapped into a real BaseOperator, a MappedOperator keeps all
    # per-task settings inside ``partial_kwargs``. The properties below expose
    # those entries under the same names and defaults BaseOperator uses, so the
    # scheduler and UI can treat both operator kinds uniformly.
    # ------------------------------------------------------------------

    @property
    def task_type(self) -> str:
        """Implementing Operator."""
        return self._task_type

    @property
    def operator_name(self) -> str:
        return self._operator_name

    @property
    def inherits_from_empty_operator(self) -> bool:
        """Implementing Operator."""
        return self._is_empty

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def task_display_name(self) -> str:
        # Falls back to task_id when no display name was given (or it is falsy).
        return self.partial_kwargs.get("task_display_name") or self.task_id

    @property
    def owner(self) -> str:  # type: ignore[override]
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def map_index_template(self) -> None | str:
        return self.partial_kwargs.get("map_index_template")

    @map_index_template.setter
    def map_index_template(self, value: str | None) -> None:
        self.partial_kwargs["map_index_template"] = value

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @trigger_rule.setter
    def trigger_rule(self, value):
        self.partial_kwargs["trigger_rule"] = value

    @property
    def is_setup(self) -> bool:
        return bool(self.partial_kwargs.get("is_setup"))

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self.partial_kwargs["is_setup"] = value

    @property
    def is_teardown(self) -> bool:
        return bool(self.partial_kwargs.get("is_teardown"))

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        self.partial_kwargs["is_teardown"] = value

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @depends_on_past.setter
    def depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["depends_on_past"] = value

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @ignore_first_depends_on_past.setter
    def ignore_first_depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["ignore_first_depends_on_past"] = value

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @wait_for_past_depends_before_skipping.setter
    def wait_for_past_depends_before_skipping(self, value: bool) -> None:
        self.partial_kwargs["wait_for_past_depends_before_skipping"] = value

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @wait_for_downstream.setter
    def wait_for_downstream(self, value: bool) -> None:
        self.partial_kwargs["wait_for_downstream"] = value

    @property
    def retries(self) -> int:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @retries.setter
    def retries(self, value: int) -> None:
        self.partial_kwargs["retries"] = value

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @queue.setter
    def queue(self, value: str) -> None:
        self.partial_kwargs["queue"] = value

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)

    @pool.setter
    def pool(self, value: str) -> None:
        self.partial_kwargs["pool"] = value

    @property
    def pool_slots(self) -> int:
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @pool_slots.setter
    def pool_slots(self, value: int) -> None:
        self.partial_kwargs["pool_slots"] = value

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @execution_timeout.setter
    def execution_timeout(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["execution_timeout"] = value

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @max_retry_delay.setter
    def max_retry_delay(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["max_retry_delay"] = value

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @retry_delay.setter
    def retry_delay(self, value: datetime.timedelta) -> None:
        self.partial_kwargs["retry_delay"] = value

    @property
    def retry_exponential_backoff(self) -> bool:
        return bool(self.partial_kwargs.get("retry_exponential_backoff"))

    @retry_exponential_backoff.setter
    def retry_exponential_backoff(self, value: bool) -> None:
        self.partial_kwargs["retry_exponential_backoff"] = value

    @property
    def priority_weight(self) -> int:  # type: ignore[override]
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @priority_weight.setter
    def priority_weight(self, value: int) -> None:
        self.partial_kwargs["priority_weight"] = value

    @property
    def weight_rule(self) -> PriorityWeightStrategy:  # type: ignore[override]
        # Stored value may be a string name; validate/load into a strategy object.
        return validate_and_load_priority_weight_strategy(
            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
        )

    @weight_rule.setter
    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
        self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

    @property
    def sla(self) -> datetime.timedelta | None:
        # Always None in practice: __attrs_post_init__ rejects a non-None sla.
        return self.partial_kwargs.get("sla")

    @sla.setter
    def sla(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["sla"] = value

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @max_active_tis_per_dag.setter
    def max_active_tis_per_dag(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dag"] = value

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @max_active_tis_per_dagrun.setter
    def max_active_tis_per_dagrun(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dagrun"] = value

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

    @property
    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_execute_callback")

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_execute_callback"] = value

    @property
    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_failure_callback")

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_failure_callback"] = value

    @property
    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_retry_callback")

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_retry_callback"] = value

    @property
    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_success_callback")

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_success_callback"] = value

    @property
    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_skipped_callback")

    @on_skipped_callback.setter
    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_skipped_callback"] = value

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor(self) -> str | None:
        return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)

    @property
    def executor_config(self) -> dict:
        # NOTE: returns a fresh empty dict each call when unset; mutations on
        # that default are not persisted back into partial_kwargs.
        return self.partial_kwargs.get("executor_config", {})

    @property  # type: ignore[override]
    def inlets(self) -> list[Any]:  # type: ignore[override]
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["inlets"] = value

    @property  # type: ignore[override]
    def outlets(self) -> list[Any]:  # type: ignore[override]
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["outlets"] = value

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    @property
    def allow_nested_operators(self) -> bool:
        return bool(self.partial_kwargs.get("allow_nested_operators"))

679 

    def get_dag(self) -> DAG | None:
        """Implement Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return reference to XCom pushed by current operator."""
        # Local import: airflow.models.xcom_arg imports this module at top level.
        from airflow.models.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implement DAGNode."""
        return DagAttributeTypes.OP, self.task_id

694 

    def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.

        Returns the resolved kwargs together with a set of object ids, which is
        later passed as *seen_oids* so template rendering skips already-resolved
        values (see ``render_template_fields``).
        """
        return self._get_specified_expand_input().resolve(context, session)

702 

703 def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: 

704 """Get init kwargs to unmap the underlying operator class. 

705 

706 :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. 

707 """ 

708 if strict: 

709 prevent_duplicates( 

710 self.partial_kwargs, 

711 mapped_kwargs, 

712 fail_reason="unmappable or already specified", 

713 ) 

714 

715 # If params appears in the mapped kwargs, we need to merge it into the 

716 # partial params, overriding existing keys. 

717 params = copy.copy(self.params) 

718 with contextlib.suppress(KeyError): 

719 params.update(mapped_kwargs["params"]) 

720 

721 # Ordering is significant; mapped kwargs should override partial ones, 

722 # and the specially handled params should be respected. 

723 return { 

724 "task_id": self.task_id, 

725 "dag": self.dag, 

726 "task_group": self.task_group, 

727 "start_date": self.start_date, 

728 "end_date": self.end_date, 

729 **self.partial_kwargs, 

730 **mapped_kwargs, 

731 "params": params, 

732 } 

733 

    def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" Operator after applying the current mapping.

        The *resolve* argument is only used if ``operator_class`` is a real
        class, i.e. if this operator is not serialized. If ``operator_class`` is
        not a class (i.e. this DAG has been deserialized), this returns a
        SerializedBaseOperator that "looks like" the actual unmapping result.

        If *resolve* is a two-tuple (context, session), the information is used
        to resolve the mapped arguments into init arguments. If it is a mapping,
        no resolving happens, the mapping directly provides those init arguments
        resolved from mapped kwargs.

        :raises RuntimeError: If a non-serialized operator is unmapped without context.
        :meta private:
        """
        if isinstance(self.operator_class, type):
            if isinstance(resolve, collections.abc.Mapping):
                kwargs = resolve
            elif resolve is not None:
                kwargs, _ = self._expand_mapped_kwargs(*resolve)
            else:
                raise RuntimeError("cannot unmap a non-serialized operator without context")
            kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
            # These flags are applied after construction rather than passed to
            # the operator's __init__ below.
            is_setup = kwargs.pop("is_setup", False)
            is_teardown = kwargs.pop("is_teardown", False)
            on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
            op = self.operator_class(**kwargs, _airflow_from_mapped=True)
            # We need to overwrite task_id here because BaseOperator further
            # mangles the task_id based on the task hierarchy (namely, group_id
            # is prepended, and '__N' appended to deduplicate). This is hacky,
            # but better than duplicating the whole mangling logic.
            op.task_id = self.task_id
            op.is_setup = is_setup
            op.is_teardown = is_teardown
            op.on_failure_fail_dagrun = on_failure_fail_dagrun
            return op

        # After a mapped operator is serialized, there's no real way to actually
        # unmap it since we've lost access to the underlying operator class.
        # This tries its best to simply "forward" all the attributes on this
        # mapped operator to a new SerializedBaseOperator instance.
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
        SerializedBaseOperator.populate_operator(op, self.operator_class)
        if self.dag is not None:  # For Mypy; we only serialize tasks in a DAG so the check always satisfies.
            SerializedBaseOperator.set_task_dag_references(op, self.dag)
        return op

782 

    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        # _expand_input_attr names the attribute to read (e.g. "expand_input").
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        """Return *self* unchanged (no copy needed for a mapped operator).

        Since a mapped operator cannot be used for execution, and an unmapped
        BaseOperator needs to be created later (see render_template_fields),
        we don't need to create a copy of the MappedOperator here.
        """
        return self

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.models.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator

799 

800 @methodtools.lru_cache(maxsize=None) 

801 def get_parse_time_mapped_ti_count(self) -> int: 

802 current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() 

803 try: 

804 parent_count = super().get_parse_time_mapped_ti_count() 

805 except NotMapped: 

806 return current_count 

807 return parent_count * current_count 

808 

809 def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: 

810 current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session) 

811 try: 

812 parent_count = super().get_mapped_ti_count(run_id, session=session) 

813 except NotMapped: 

814 return current_count 

815 return parent_count * current_count 

816 

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        if not jinja_env:
            jinja_env = self.get_template_env()

        # We retrieve the session here, stored by _run_raw_task in set_current_task_session
        # context manager - we cannot pass the session via @provide_session because the signature
        # of render_template_fields is defined by BaseOperator and there are already many subclasses
        # overriding it, so changing the signature is not an option. However render_template_fields is
        # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the
        # set_current_task_session context manager to store the session in the current task.
        session = get_current_task_instance_session()

        # Resolve the mapped arguments for this map index, build the real
        # operator instance from them, and point the context at it.
        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
        )