Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/mappedoperator.py: 49%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

492 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import contextlib 

21import copy 

22import warnings 

23from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence 

24from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeGuard 

25 

26import attrs 

27import methodtools 

28from lazy_object_proxy import Proxy 

29 

30from airflow.sdk.bases.xcom import BaseXCom 

31from airflow.sdk.definitions._internal.abstractoperator import ( 

32 DEFAULT_EXECUTOR, 

33 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

34 DEFAULT_OWNER, 

35 DEFAULT_POOL_NAME, 

36 DEFAULT_POOL_SLOTS, 

37 DEFAULT_PRIORITY_WEIGHT, 

38 DEFAULT_QUEUE, 

39 DEFAULT_RETRIES, 

40 DEFAULT_RETRY_DELAY, 

41 DEFAULT_TRIGGER_RULE, 

42 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

43 DEFAULT_WEIGHT_RULE, 

44 AbstractOperator, 

45 TaskStateChangeCallbackAttrType, 

46) 

47from airflow.sdk.definitions._internal.expandinput import ( 

48 DictOfListsExpandInput, 

49 ListOfDictsExpandInput, 

50 is_mappable, 

51) 

52from airflow.sdk.definitions._internal.types import NOTSET 

53from airflow.serialization.enums import DagAttributeTypes 

54from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy 

55 

56if TYPE_CHECKING: 

57 import datetime 

58 

59 import jinja2 # Slow import. 

60 import pendulum 

61 

62 from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, TriggerRule, XComArg 

63 from airflow.sdk.definitions._internal.expandinput import ( 

64 ExpandInput, 

65 OperatorExpandArgument, 

66 OperatorExpandKwargsArgument, 

67 ) 

68 from airflow.sdk.definitions.operator_resources import Resources 

69 from airflow.sdk.definitions.param import ParamsDict 

70 from airflow.triggers.base import StartTriggerArgs 

71 

# Which user-facing entry point is being validated ("expand" or "partial");
# used to tailor validation behavior and error messages.
# PEP 586 spells a union of literals as a single Literal with multiple values.
ValidationSource = Literal["expand", "partial"]

73 

74 

def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """
    Check that *value* only contains keyword arguments *op* accepts.

    Walks the operator class's MRO and removes every known ``__init__``
    parameter name from a working copy of *value*; for ``expand`` calls each
    supplied argument must also be a mappable value. Anything left over after
    the walk is an unknown argument.

    :raises ValueError: If an ``expand`` argument has an unmappable type.
    :raises TypeError: If unknown keyword arguments remain.
    """
    # Work on a copy so the caller's mapping is untouched; a dict keeps the
    # arguments in code order for error reporting.
    remaining = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            # Not a BaseOperatorMeta-instrumented class; nothing to check here.
            continue
        for name in param_names:
            supplied = remaining.pop(name, NOTSET)
            if func != "expand" or supplied is NOTSET or is_mappable(supplied):
                continue
            type_name = type(supplied).__name__
            raise ValueError(
                f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            )
        if not remaining:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(remaining) == 1:
        detail = f"an unexpected keyword argument {remaining.popitem()[0]!r}"
    else:
        joined = ", ".join(repr(n) for n in remaining)
        detail = f"unexpected keyword arguments {joined}"
    raise TypeError(f"{op.__name__}.{func}() got {detail}")

104 

105 

def _is_container(obj: Any) -> bool:
    """Test if an object is a container (iterable) but not a string."""
    if isinstance(obj, Proxy):
        # A Proxy always implements __iter__ (it forwards the call to the
        # lazily initialized target), so unwrap it first and inspect the
        # proxied object instead of the wrapper.
        obj = obj.__wrapped__
    if isinstance(obj, str):
        return False
    return hasattr(obj, "__iter__")

114 

115 

def ensure_xcomarg_return_value(arg: Any) -> None:
    """
    Recursively verify that any XComArg inside *arg* maps over the default key.

    :raises ValueError: If an XComArg references a custom XCom key; only
        return values (the default key) can be mapped over.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        for operator, key in arg.iter_references():
            if key != BaseXCom.XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not _is_container(arg):
        # Scalars (and strings) cannot nest XComArgs; nothing to inspect.
        return
    if isinstance(arg, Mapping):
        for item in arg.values():
            ensure_xcomarg_return_value(item)
    elif isinstance(arg, Iterable):
        for item in arg:
            ensure_xcomarg_return_value(item)

131 

132 

133def is_mappable_value(value: Any) -> TypeGuard[Collection]: 

134 """ 

135 Whether a value can be used for task mapping. 

136 

137 We only allow collections with guaranteed ordering, but exclude character 

138 sequences since that's usually not what users would expect to be mappable. 

139 

140 :meta private: 

141 """ 

142 if not isinstance(value, (Sequence, dict)): 

143 return False 

144 if isinstance(value, (bytearray, bytes, str)): 

145 return False 

146 return True 

147 

148 

def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """
    Ensure *kwargs1* and *kwargs2* do not contain common keys.

    :raises TypeError: If common keys are found.
    """
    common = set(kwargs1) & set(kwargs2)
    if not common:
        return
    if len(common) == 1:
        raise TypeError(f"{fail_reason} argument: {common.pop()}")
    raise TypeError(f"{fail_reason} arguments: {', '.join(sorted(common))}")

162 

163 

@attrs.define(kw_only=True, repr=False)
class OperatorPartial:
    """
    An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at Dag-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the Dag.
    """

    operator_class: type[BaseOperator]
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # Warn if a partial operator is garbage-collected without ever being
        # expanded -- that is almost certainly a Dag authoring mistake.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """
        Expand this partial into a ``MappedOperator`` over the given kwargs.

        :raises TypeError: If no arguments, or unknown/duplicate arguments,
            are supplied.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """
        Expand against a list of kwargs dicts (or an XComArg resolving to one).

        :raises TypeError: If *kwargs* (or any element of it) has a wrong type.
        """
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # Report the offending element's type, not the container's,
                    # so the user can tell which kind of entry is invalid.
                    raise TypeError(f"expected XComArg or list[dict], not {type(item).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` from this partial and *expand_input*."""
        from airflow.providers.standard.operators.empty import EmptyOperator
        from airflow.providers.standard.utils.skipmixin import SkipMixin
        from airflow.sdk import BaseSensorOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        # These become explicit MappedOperator attributes rather than staying
        # in partial_kwargs.
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date", None)
        end_date = partial_kwargs.pop("end_date", None)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            is_sensor=issubclass(self.operator_class, BaseSensorOperator),
            can_skip_downstream=issubclass(self.operator_class, SkipMixin),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            # TODO: Move these to task SDK's BaseOperator and remove getattr
            start_trigger_args=getattr(self.operator_class, "start_trigger_args", None),
            start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)),
        )
        return op

266 

267 

@attrs.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's Dag serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through Dag pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a Dag."""

    # The unmapped operator class; instantiated lazily by unmap().
    operator_class: type[BaseOperator]

    _is_mapped: bool = attrs.field(init=False, default=True)

    # The arguments supplied to expand()/expand_kwargs(), and the plain
    # (non-mapped) keyword arguments supplied to partial().
    expand_input: ExpandInput
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool = attrs.field(alias="is_empty")
    _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream")
    _is_sensor: bool = attrs.field(alias="is_sensor", default=False)
    _task_module: str
    task_type: str
    _operator_name: str
    start_trigger_args: StartTriggerArgs | None
    start_from_trigger: bool
    _needs_expansion: bool = True

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
    )

    def __hash__(self):
        # Hash by identity: each MappedOperator instance is a distinct task.
        return id(self)

    def __repr__(self):
        return f"<Mapped({self.task_type}): {self.task_id}>"

    def __attrs_post_init__(self):
        """Validate the expansion context and register this task with its dag/group."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        # Expanding inside an already-expanded task group is not supported yet.
        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        # Any XComArg in the expand input becomes an upstream dependency, as
        # does any XComArg passed into a templated partial argument.
        XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)

    @methodtools.lru_cache(maxsize=None)
    @classmethod
    def get_serialized_fields(cls):
        """Return the set of attribute names included in serialization."""
        # Not using 'cls' here since we only want to serialize base fields.
        return (frozenset(attrs.fields_dict(MappedOperator))) - {
            "_is_empty",
            "_can_skip_downstream",
            "dag",
            "deps",
            "expand_input",  # This is needed to be able to accept XComArg.
            "task_group",
            "upstream_task_ids",
            "_is_setup",
            "_is_teardown",
            "_on_failure_fail_dagrun",
            "operator_class",
            "_needs_expansion",
            "partial_kwargs",
            "operator_extra_links",
        }

    @property
    def operator_name(self) -> str:
        return self._operator_name

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    # Most BaseOperator attributes are not stored on the mapped operator
    # itself; they live in partial_kwargs and are exposed (with the same
    # defaults as BaseOperator) through the forwarding properties below.

    @property
    def task_display_name(self) -> str:
        return self.partial_kwargs.get("task_display_name") or self.task_id

    @property
    def owner(self) -> str:
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @owner.setter
    def owner(self, value: str) -> None:
        self.partial_kwargs["owner"] = value

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def email_on_failure(self) -> bool:
        return self.partial_kwargs.get("email_on_failure", True)

    @property
    def email_on_retry(self) -> bool:
        return self.partial_kwargs.get("email_on_retry", True)

    @property
    def map_index_template(self) -> None | str:
        return self.partial_kwargs.get("map_index_template")

    @map_index_template.setter
    def map_index_template(self, value: str | None) -> None:
        self.partial_kwargs["map_index_template"] = value

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @trigger_rule.setter
    def trigger_rule(self, value):
        self.partial_kwargs["trigger_rule"] = value

    @property
    def is_setup(self) -> bool:
        return bool(self.partial_kwargs.get("is_setup"))

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self.partial_kwargs["is_setup"] = value

    @property
    def is_teardown(self) -> bool:
        return bool(self.partial_kwargs.get("is_teardown"))

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        self.partial_kwargs["is_teardown"] = value

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @depends_on_past.setter
    def depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["depends_on_past"] = value

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @ignore_first_depends_on_past.setter
    def ignore_first_depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["ignore_first_depends_on_past"] = value

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @wait_for_past_depends_before_skipping.setter
    def wait_for_past_depends_before_skipping(self, value: bool) -> None:
        self.partial_kwargs["wait_for_past_depends_before_skipping"] = value

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @wait_for_downstream.setter
    def wait_for_downstream(self, value: bool) -> None:
        self.partial_kwargs["wait_for_downstream"] = value

    @property
    def retries(self) -> int:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @retries.setter
    def retries(self, value: int) -> None:
        self.partial_kwargs["retries"] = value

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @queue.setter
    def queue(self, value: str) -> None:
        self.partial_kwargs["queue"] = value

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME)

    @pool.setter
    def pool(self, value: str) -> None:
        self.partial_kwargs["pool"] = value

    @property
    def pool_slots(self) -> int:
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @pool_slots.setter
    def pool_slots(self, value: int) -> None:
        self.partial_kwargs["pool_slots"] = value

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @execution_timeout.setter
    def execution_timeout(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["execution_timeout"] = value

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @max_retry_delay.setter
    def max_retry_delay(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["max_retry_delay"] = value

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @retry_delay.setter
    def retry_delay(self, value: datetime.timedelta) -> None:
        self.partial_kwargs["retry_delay"] = value

    @property
    def retry_exponential_backoff(self) -> float:
        # Bool values are coerced: True -> 2.0, False -> 0.0; anything else is
        # treated as the float backoff multiplier directly.
        value = self.partial_kwargs.get("retry_exponential_backoff", 0)
        if value is True:
            return 2.0
        if value is False:
            return 0.0
        return float(value)

    @retry_exponential_backoff.setter
    def retry_exponential_backoff(self, value: float) -> None:
        self.partial_kwargs["retry_exponential_backoff"] = value

    @property
    def priority_weight(self) -> int:
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @priority_weight.setter
    def priority_weight(self, value: int) -> None:
        self.partial_kwargs["priority_weight"] = value

    @property
    def weight_rule(self) -> PriorityWeightStrategy:
        # Stored value may be a plain string; normalize on read.
        return validate_and_load_priority_weight_strategy(
            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
        )

    @weight_rule.setter
    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
        self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @max_active_tis_per_dag.setter
    def max_active_tis_per_dag(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dag"] = value

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @max_active_tis_per_dagrun.setter
    def max_active_tis_per_dagrun(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dagrun"] = value

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

    # Callback properties normalize a falsy stored value to an empty list so
    # callers can always iterate the result.

    @property
    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_execute_callback") or []

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_execute_callback"] = value or []

    @property
    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_failure_callback") or []

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_failure_callback"] = value or []

    @property
    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_retry_callback") or []

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_retry_callback"] = value or []

    @property
    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_success_callback") or []

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_success_callback"] = value or []

    @property
    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_skipped_callback") or []

    @on_skipped_callback.setter
    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_skipped_callback"] = value or []

    @property
    def has_on_execute_callback(self) -> bool:
        return bool(self.on_execute_callback)

    @property
    def has_on_failure_callback(self) -> bool:
        return bool(self.on_failure_callback)

    @property
    def has_on_retry_callback(self) -> bool:
        return bool(self.on_retry_callback)

    @property
    def has_on_success_callback(self) -> bool:
        return bool(self.on_success_callback)

    @property
    def has_on_skipped_callback(self) -> bool:
        return bool(self.on_skipped_callback)

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor(self) -> str | None:
        return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property
    def inlets(self) -> list[Any]:
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:
        self.partial_kwargs["inlets"] = value

    @property
    def outlets(self) -> list[Any]:
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:
        self.partial_kwargs["outlets"] = value

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    @property
    def allow_nested_operators(self) -> bool:
        return bool(self.partial_kwargs.get("allow_nested_operators"))

    def get_dag(self) -> DAG | None:
        """Implement Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return reference to XCom pushed by current operator."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implement DAGNode."""
        return DagAttributeTypes.OP, self.task_id

    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.
        """
        return self._get_specified_expand_input().resolve(context)

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """
        Get init kwargs to unmap the underlying operator class.

        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
        """
        if strict:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # If params appears in the mapped kwargs, we need to merge it into the
        # partial params, overriding existing keys.
        params = copy.copy(self.params)
        with contextlib.suppress(KeyError):
            params.update(mapped_kwargs["params"])

        # Ordering is significant; mapped kwargs should override partial ones,
        # and the specially handled params should be respected.
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **mapped_kwargs,
            "params": params,
        }

    def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator:
        """
        Get the "normal" Operator after applying the current mapping.

        :meta private:
        """
        if isinstance(resolve, Mapping):
            # Already-resolved mapped kwargs.
            kwargs = resolve
        elif resolve is not None:
            # NOTE(review): despite the annotation, this branch unpacks
            # *resolve* as an argument tuple for _expand_mapped_kwargs --
            # presumably a legacy call form; confirm against callers.
            kwargs, _ = self._expand_mapped_kwargs(*resolve)
        else:
            raise RuntimeError("cannot unmap a non-serialized operator without context")
        kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
        # These are not BaseOperator.__init__ kwargs; pop and assign after
        # construction instead.
        is_setup = kwargs.pop("is_setup", False)
        is_teardown = kwargs.pop("is_teardown", False)
        on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
        kwargs["task_id"] = self.task_id
        op = self.operator_class(**kwargs, _airflow_from_mapped=True)
        op.is_setup = is_setup
        op.is_teardown = is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        op.downstream_task_ids = self.downstream_task_ids
        op.upstream_task_ids = self.upstream_task_ids
        return op

    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    # TODO (GH-52141): Do we need this in the SDK?
    def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """
        Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        from airflow.sdk.execution_time.context import context_update_for_unmapped

        if not jinja_env:
            jinja_env = self.get_template_env()

        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
        )

    def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
        """
        Get the kwargs to create the unmapped start_trigger_args.

        This method is for allowing mapped operator to start execution from triggerer.
        """
        from airflow.triggers.base import StartTriggerArgs

        if not self.start_trigger_args:
            return None

        mapped_kwargs, _ = self._expand_mapped_kwargs(context)
        if self._disallow_kwargs_override:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # Ordering is significant; mapped kwargs should override partial ones.
        trigger_kwargs = mapped_kwargs.get(
            "trigger_kwargs",
            self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
        )
        next_kwargs = mapped_kwargs.get(
            "next_kwargs",
            self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs),
        )
        timeout = mapped_kwargs.get(
            "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout)
        )
        return StartTriggerArgs(
            trigger_cls=self.start_trigger_args.trigger_cls,
            trigger_kwargs=trigger_kwargs,
            next_method=self.start_trigger_args.next_method,
            next_kwargs=next_kwargs,
            timeout=timeout,
        )