Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/mappedoperator.py: 49%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

501 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import contextlib 

21import copy 

22import warnings 

23from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence 

24from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeGuard 

25 

26import attrs 

27import methodtools 

28from lazy_object_proxy import Proxy 

29 

30from airflow.sdk.bases.xcom import BaseXCom 

31from airflow.sdk.definitions._internal.abstractoperator import ( 

32 DEFAULT_EXECUTOR, 

33 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

34 DEFAULT_OWNER, 

35 DEFAULT_POOL_NAME, 

36 DEFAULT_POOL_SLOTS, 

37 DEFAULT_PRIORITY_WEIGHT, 

38 DEFAULT_QUEUE, 

39 DEFAULT_RETRIES, 

40 DEFAULT_RETRY_DELAY, 

41 DEFAULT_TRIGGER_RULE, 

42 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

43 DEFAULT_WEIGHT_RULE, 

44 AbstractOperator, 

45 NotMapped, 

46 TaskStateChangeCallbackAttrType, 

47) 

48from airflow.sdk.definitions._internal.expandinput import ( 

49 DictOfListsExpandInput, 

50 ListOfDictsExpandInput, 

51 is_mappable, 

52) 

53from airflow.sdk.definitions._internal.types import NOTSET 

54from airflow.serialization.enums import DagAttributeTypes 

55from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy 

56 

57if TYPE_CHECKING: 

58 import datetime 

59 

60 import jinja2 # Slow import. 

61 import pendulum 

62 

63 from airflow.models.expandinput import ( 

64 OperatorExpandArgument, 

65 OperatorExpandKwargsArgument, 

66 ) 

67 from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, TriggerRule, XComArg 

68 from airflow.sdk.definitions._internal.expandinput import ExpandInput 

69 from airflow.sdk.definitions.operator_resources import Resources 

70 from airflow.sdk.definitions.param import ParamsDict 

71 from airflow.triggers.base import StartTriggerArgs 

72 

# Identifies which mapping API ("expand" or "partial") supplied the kwargs
# currently being validated; used to select the right error message.
ValidationSource = Literal["expand"] | Literal["partial"]

74 

75 

def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """
    Validate kwargs passed to ``partial()`` or ``expand()`` against *op*'s signature.

    Walks the operator class's MRO, removing every keyword the class declares
    (via the ``_BaseOperatorMeta__param_names`` attribute stamped onto each
    ``__init__``). For ``expand`` calls, each supplied value must additionally
    be mappable.

    :param op: The operator class being partial-ed/expanded.
    :param func: Which API supplied the kwargs ("partial" or "expand").
    :param value: The keyword arguments to validate.
    :raises ValueError: If an ``expand()`` value is not a mappable type.
    :raises TypeError: If any keyword is not accepted by the operator.
    """
    # Work on a copy so leftovers are reported in the same order as code order.
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            # Not created by BaseOperatorMeta; nothing declared at this level.
            continue
        for name in param_names:
            supplied = unknown_args.pop(name, NOTSET)
            # Mappability only matters for expand(); absent values are fine.
            if func != "expand" or supplied is NOTSET or is_mappable(supplied):
                continue
            type_name = type(supplied).__name__
            raise ValueError(
                f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            )
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        detail = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        detail = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {detail}")

105 

106 

def _is_container(obj: Any) -> bool:
    """Test if an object is a container (iterable) but not a string."""
    if isinstance(obj, Proxy):
        # A Proxy always implements __iter__ (it forwards the call to the
        # lazily initialized target), so unwrap it first and inspect the
        # proxied object instead.
        obj = obj.__wrapped__
    if isinstance(obj, str):
        return False
    return hasattr(obj, "__iter__")

115 

116 

def ensure_xcomarg_return_value(arg: Any) -> None:
    """
    Recursively verify that every XComArg in *arg* maps over the default return key.

    :raises ValueError: If any referenced XCom uses a custom (non-return) key.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        for operator, key in arg.iter_references():
            if key != BaseXCom.XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not _is_container(arg):
        # Scalars (and strings) cannot hold nested XComArgs.
        return
    if isinstance(arg, Mapping):
        for v in arg.values():
            ensure_xcomarg_return_value(v)
    elif isinstance(arg, Iterable):
        for v in arg:
            ensure_xcomarg_return_value(v)

132 

133 

134def is_mappable_value(value: Any) -> TypeGuard[Collection]: 

135 """ 

136 Whether a value can be used for task mapping. 

137 

138 We only allow collections with guaranteed ordering, but exclude character 

139 sequences since that's usually not what users would expect to be mappable. 

140 

141 :meta private: 

142 """ 

143 if not isinstance(value, (Sequence, dict)): 

144 return False 

145 if isinstance(value, (bytearray, bytes, str)): 

146 return False 

147 return True 

148 

149 

def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """
    Ensure *kwargs1* and *kwargs2* do not contain common keys.

    :raises TypeError: If common keys are found.
    """
    common = set(kwargs1) & set(kwargs2)
    if not common:
        return
    if len(common) == 1:
        raise TypeError(f"{fail_reason} argument: {common.pop()}")
    # Sort so the message is deterministic regardless of set ordering.
    raise TypeError(f"{fail_reason} arguments: {', '.join(sorted(common))}")

163 

164 

@attrs.define(kw_only=True, repr=False)
class OperatorPartial:
    """
    An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at Dag-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the Dag.
    """

    # The operator class to map; instantiated only when the task is unmapped.
    operator_class: type[BaseOperator]
    # Keyword arguments fixed by partial(); shared by every mapped task instance.
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # Warn on the common mistake of calling partial() without a follow-up
        # expand(); such a task never makes it into the Dag.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """
        Map this operator over every entry of the given keyword arguments.

        :raises TypeError: If no arguments are given or a keyword is unknown.
        :raises ValueError: If a supplied value is not mappable.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """
        Map this operator over a sequence of complete kwargs dicts (or an XComArg
        resolving to one).

        :raises TypeError: If *kwargs* (or one of its elements) has the wrong type.
        """
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # Report the offending element's type, not the container's;
                    # "not list" would be misleading when a list element is bad.
                    raise TypeError(f"expected XComArg or list[dict], not {type(item).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator``, splitting identity kwargs out of the partial ones."""
        from airflow.providers.standard.operators.empty import EmptyOperator
        from airflow.providers.standard.utils.skipmixin import SkipMixin
        from airflow.sdk import BaseSensorOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        # Identity/placement arguments become MappedOperator attributes; the
        # rest stay in partial_kwargs until the task is unmapped.
        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date", None)
        end_date = partial_kwargs.pop("end_date", None)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            is_sensor=issubclass(self.operator_class, BaseSensorOperator),
            can_skip_downstream=issubclass(self.operator_class, SkipMixin),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            # TODO: Move these to task SDK's BaseOperator and remove getattr
            start_trigger_args=getattr(self.operator_class, "start_trigger_args", None),
            start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)),
        )
        return op

267 

268 

@attrs.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's Dag serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through Dag pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a Dag.

    Constructed by ``OperatorPartial._expand()``. Holds the concrete operator
    class together with ``partial_kwargs`` (static arguments) and
    ``expand_input`` (the mapped arguments); a real operator instance is only
    created when :meth:`unmap` is called.
    """

    # Concrete class instantiated by unmap().
    operator_class: type[BaseOperator]

    _is_mapped: bool = attrs.field(init=False, default=True)

    expand_input: ExpandInput
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool = attrs.field(alias="is_empty")
    _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream")
    _is_sensor: bool = attrs.field(alias="is_sensor", default=False)
    _task_module: str
    task_type: str
    _operator_name: str
    start_trigger_args: StartTriggerArgs | None
    start_from_trigger: bool
    _needs_expansion: bool = True

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
    )

    def __hash__(self) -> int:
        # Identity-based: each MappedOperator object hashes to itself only.
        return id(self)

    def __repr__(self) -> str:
        return f"<Mapped({self.task_type}): {self.task_id}>"

    def __attrs_post_init__(self) -> None:
        """Register the task with its task group/Dag and wire upstream XComArg links."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value)
        # Templated partial kwargs may also carry XComArg references.
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)

    @methodtools.lru_cache(maxsize=None)
    @classmethod
    def get_serialized_fields(cls):
        """Return the attrs field names to serialize, minus runtime-only ones."""
        # Not using 'cls' here since we only want to serialize base fields.
        return (frozenset(attrs.fields_dict(MappedOperator))) - {
            "_is_empty",
            "_can_skip_downstream",
            "dag",
            "deps",
            "expand_input",  # This is needed to be able to accept XComArg.
            "task_group",
            "upstream_task_ids",
            "_is_setup",
            "_is_teardown",
            "_on_failure_fail_dagrun",
            "operator_class",
            "_needs_expansion",
            "partial_kwargs",
            "operator_extra_links",
        }

    @property
    def operator_name(self) -> str:
        return self._operator_name

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    # ------------------------------------------------------------------
    # The properties below proxy per-task settings through ``partial_kwargs``
    # (falling back to the module-level DEFAULT_* constants) so the mapped task
    # reports the same values the eventual unmapped operator will receive.
    # Setters write back into ``partial_kwargs``.
    # ------------------------------------------------------------------

    @property
    def task_display_name(self) -> str:
        return self.partial_kwargs.get("task_display_name") or self.task_id

    @property
    def owner(self) -> str:
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @owner.setter
    def owner(self, value: str) -> None:
        self.partial_kwargs["owner"] = value

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def email_on_failure(self) -> bool:
        return self.partial_kwargs.get("email_on_failure", True)

    @property
    def email_on_retry(self) -> bool:
        return self.partial_kwargs.get("email_on_retry", True)

    @property
    def map_index_template(self) -> None | str:
        return self.partial_kwargs.get("map_index_template")

    @map_index_template.setter
    def map_index_template(self, value: str | None) -> None:
        self.partial_kwargs["map_index_template"] = value

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @trigger_rule.setter
    def trigger_rule(self, value) -> None:
        self.partial_kwargs["trigger_rule"] = value

    @property
    def is_setup(self) -> bool:
        return bool(self.partial_kwargs.get("is_setup"))

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self.partial_kwargs["is_setup"] = value

    @property
    def is_teardown(self) -> bool:
        return bool(self.partial_kwargs.get("is_teardown"))

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        self.partial_kwargs["is_teardown"] = value

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @depends_on_past.setter
    def depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["depends_on_past"] = value

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @ignore_first_depends_on_past.setter
    def ignore_first_depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["ignore_first_depends_on_past"] = value

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @wait_for_past_depends_before_skipping.setter
    def wait_for_past_depends_before_skipping(self, value: bool) -> None:
        self.partial_kwargs["wait_for_past_depends_before_skipping"] = value

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @wait_for_downstream.setter
    def wait_for_downstream(self, value: bool) -> None:
        self.partial_kwargs["wait_for_downstream"] = value

    @property
    def retries(self) -> int:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @retries.setter
    def retries(self, value: int) -> None:
        self.partial_kwargs["retries"] = value

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @queue.setter
    def queue(self, value: str) -> None:
        self.partial_kwargs["queue"] = value

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME)

    @pool.setter
    def pool(self, value: str) -> None:
        self.partial_kwargs["pool"] = value

    @property
    def pool_slots(self) -> int:
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @pool_slots.setter
    def pool_slots(self, value: int) -> None:
        self.partial_kwargs["pool_slots"] = value

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @execution_timeout.setter
    def execution_timeout(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["execution_timeout"] = value

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @max_retry_delay.setter
    def max_retry_delay(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["max_retry_delay"] = value

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @retry_delay.setter
    def retry_delay(self, value: datetime.timedelta) -> None:
        self.partial_kwargs["retry_delay"] = value

    @property
    def retry_exponential_backoff(self) -> float:
        # Legacy boolean settings are normalized: True -> 2.0, False -> 0.0;
        # anything else is coerced to float.
        value = self.partial_kwargs.get("retry_exponential_backoff", 0)
        if value is True:
            return 2.0
        if value is False:
            return 0.0
        return float(value)

    @retry_exponential_backoff.setter
    def retry_exponential_backoff(self, value: float) -> None:
        self.partial_kwargs["retry_exponential_backoff"] = value

    @property
    def priority_weight(self) -> int:
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @priority_weight.setter
    def priority_weight(self, value: int) -> None:
        self.partial_kwargs["priority_weight"] = value

    @property
    def weight_rule(self) -> PriorityWeightStrategy:
        return validate_and_load_priority_weight_strategy(
            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
        )

    @weight_rule.setter
    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
        # Validated on write as well as read, so bad values fail early.
        self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @max_active_tis_per_dag.setter
    def max_active_tis_per_dag(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dag"] = value

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @max_active_tis_per_dagrun.setter
    def max_active_tis_per_dagrun(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dagrun"] = value

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

    # Lifecycle callbacks: getters normalize a missing/None entry to an empty
    # list; setters apply the same normalization on write.

    @property
    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_execute_callback") or []

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_execute_callback"] = value or []

    @property
    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_failure_callback") or []

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_failure_callback"] = value or []

    @property
    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_retry_callback") or []

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_retry_callback"] = value or []

    @property
    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_success_callback") or []

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_success_callback"] = value or []

    @property
    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_skipped_callback") or []

    @on_skipped_callback.setter
    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_skipped_callback"] = value or []

    @property
    def has_on_execute_callback(self) -> bool:
        return bool(self.on_execute_callback)

    @property
    def has_on_failure_callback(self) -> bool:
        return bool(self.on_failure_callback)

    @property
    def has_on_retry_callback(self) -> bool:
        return bool(self.on_retry_callback)

    @property
    def has_on_success_callback(self) -> bool:
        return bool(self.on_success_callback)

    @property
    def has_on_skipped_callback(self) -> bool:
        return bool(self.on_skipped_callback)

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor(self) -> str | None:
        return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property
    def inlets(self) -> list[Any]:
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:
        self.partial_kwargs["inlets"] = value

    @property
    def outlets(self) -> list[Any]:
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:
        self.partial_kwargs["outlets"] = value

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    @property
    def allow_nested_operators(self) -> bool:
        return bool(self.partial_kwargs.get("allow_nested_operators"))

    def get_dag(self) -> DAG | None:
        """Implement Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return reference to XCom pushed by current operator."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implement DAGNode."""
        return DagAttributeTypes.OP, self.task_id

    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.
        """
        return self._get_specified_expand_input().resolve(context)

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """
        Get init kwargs to unmap the underlying operator class.

        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
        :param strict: If *True*, raise when a mapped kwarg duplicates a partial one.
        """
        if strict:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # If params appears in the mapped kwargs, we need to merge it into the
        # partial params, overriding existing keys.
        params = copy.copy(self.params)
        with contextlib.suppress(KeyError):
            params.update(mapped_kwargs["params"])

        # Ordering is significant; mapped kwargs should override partial ones,
        # and the specially handled params should be respected.
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **mapped_kwargs,
            "params": params,
        }

    def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator:
        """
        Get the "normal" Operator after applying the current mapping.

        Instantiates ``operator_class`` with the merged partial + mapped kwargs
        and copies task relationships and setup/teardown flags onto it.

        :meta private:
        """
        if isinstance(resolve, Mapping):
            kwargs = resolve
        elif resolve is not None:
            # NOTE(review): for the annotated type (None | Mapping) this branch
            # is unreachable; it unpacks *resolve* as positional args — confirm
            # whether tuple inputs from older callers are still expected.
            kwargs, _ = self._expand_mapped_kwargs(*resolve)
        else:
            raise RuntimeError("cannot unmap a non-serialized operator without context")
        kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
        # These flags are not constructor arguments; pop and set them afterwards.
        is_setup = kwargs.pop("is_setup", False)
        is_teardown = kwargs.pop("is_teardown", False)
        on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
        kwargs["task_id"] = self.task_id
        op = self.operator_class(**kwargs, _airflow_from_mapped=True)
        op.is_setup = is_setup
        op.is_teardown = is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        op.downstream_task_ids = self.downstream_task_ids
        op.upstream_task_ids = self.upstream_task_ids
        return op

    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    # TODO (GH-52141): Do we need this in the SDK?
    def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator

    @methodtools.lru_cache(maxsize=None)
    def get_parse_time_mapped_ti_count(self) -> int:
        """Return how many task instances this task maps to, as known at parse time.

        Multiplies this task's own expansion count with the enclosing (mapped)
        task group's count, if any.
        """
        current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
        try:
            # The use of `methodtools` interferes with the zero-arg super
            parent_count = super(MappedOperator, self).get_parse_time_mapped_ti_count()  # noqa: UP008
        except NotMapped:
            return current_count
        return parent_count * current_count

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """
        Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        from airflow.sdk.execution_time.context import context_update_for_unmapped

        if not jinja_env:
            jinja_env = self.get_template_env()

        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
        )

    def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
        """
        Get the kwargs to create the unmapped start_trigger_args.

        This method is for allowing mapped operator to start execution from triggerer.
        """
        from airflow.triggers.base import StartTriggerArgs

        if not self.start_trigger_args:
            return None

        mapped_kwargs, _ = self._expand_mapped_kwargs(context)
        if self._disallow_kwargs_override:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # Ordering is significant; mapped kwargs should override partial ones.
        trigger_kwargs = mapped_kwargs.get(
            "trigger_kwargs",
            self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
        )
        next_kwargs = mapped_kwargs.get(
            "next_kwargs",
            self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs),
        )
        timeout = mapped_kwargs.get(
            "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout)
        )
        return StartTriggerArgs(
            trigger_cls=self.start_trigger_args.trigger_cls,
            trigger_kwargs=trigger_kwargs,
            next_method=self.start_trigger_args.next_method,
            next_kwargs=next_kwargs,
            timeout=timeout,
        )