Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/mappedoperator.py: 49%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

487 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import contextlib 

21import copy 

22import warnings 

23from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence 

24from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeGuard 

25 

26import attrs 

27import methodtools 

28from lazy_object_proxy import Proxy 

29 

30from airflow.sdk.bases.xcom import BaseXCom 

31from airflow.sdk.definitions._internal.abstractoperator import ( 

32 DEFAULT_EXECUTOR, 

33 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

34 DEFAULT_OWNER, 

35 DEFAULT_POOL_NAME, 

36 DEFAULT_POOL_SLOTS, 

37 DEFAULT_PRIORITY_WEIGHT, 

38 DEFAULT_QUEUE, 

39 DEFAULT_RETRIES, 

40 DEFAULT_RETRY_DELAY, 

41 DEFAULT_TRIGGER_RULE, 

42 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

43 DEFAULT_WEIGHT_RULE, 

44 AbstractOperator, 

45 TaskStateChangeCallbackAttrType, 

46) 

47from airflow.sdk.definitions._internal.expandinput import ( 

48 DictOfListsExpandInput, 

49 ListOfDictsExpandInput, 

50 is_mappable, 

51) 

52from airflow.sdk.definitions._internal.types import NOTSET 

53from airflow.serialization.enums import DagAttributeTypes 

54 

55if TYPE_CHECKING: 

56 import datetime 

57 

58 import jinja2 # Slow import. 

59 import pendulum 

60 

61 from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Context, TaskGroup, TriggerRule, XComArg 

62 from airflow.sdk.definitions._internal.expandinput import ( 

63 ExpandInput, 

64 OperatorExpandArgument, 

65 OperatorExpandKwargsArgument, 

66 ) 

67 from airflow.sdk.definitions.operator_resources import Resources 

68 from airflow.sdk.definitions.param import ParamsDict 

69 from airflow.task.priority_strategy import PriorityWeightStrategy 

70 from airflow.triggers.base import StartTriggerArgs 

71 

# Which user-facing call site is being validated; used only for error messages.
# ``Literal["expand", "partial"]`` is the canonical PEP 586 spelling of the
# former ``Literal["expand"] | Literal["partial"]`` union.
ValidationSource = Literal["expand", "partial"]

73 

74 

def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """
    Validate user-supplied kwargs for ``partial()`` or ``expand()``.

    Walks *op*'s MRO collecting the parameter names recorded on each
    ``__init__`` by ``BaseOperatorMeta``, and checks that every key in *value*
    matches one of them. For ``expand()`` calls, additionally checks that each
    supplied value is mappable.

    :param op: The operator class being partially initialized or expanded.
    :param func: Which call is being validated ("expand" or "partial");
        only used to pick the right validation rules and error message.
    :param value: The keyword arguments the user passed.
    :raises ValueError: If an ``expand()`` argument is not a mappable type.
    :raises TypeError: If any argument is not accepted by the operator.
    """
    # use a dict so order of args is same as code order
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            # Use a dedicated local instead of rebinding (shadowing) the
            # *value* parameter, which previously held the whole kwargs dict.
            arg_value = unknown_args.pop(name, NOTSET)
            if func != "expand":
                continue
            if arg_value is NOTSET:
                continue
            if is_mappable(arg_value):
                continue
            type_name = type(arg_value).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")

104 

105 

def _is_container(obj: Any) -> bool:
    """Test if an object is a container (iterable) but not a string."""
    # A lazy Proxy implements __iter__ to forward the call to its lazily
    # initialized target, so every Proxy would look iterable. Unwrap it and
    # inspect the proxied object itself instead.
    target = obj.__wrapped__ if isinstance(obj, Proxy) else obj
    return not isinstance(target, str) and hasattr(target, "__iter__")

114 

115 

def ensure_xcomarg_return_value(arg: Any) -> None:
    """
    Recursively check that every XComArg in *arg* pulls the default return value.

    Raises ``ValueError`` if any referenced XCom uses a custom key, since only
    the default return-value key can be mapped over.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(arg, XComArg):
        # Validate each operator/key reference carried by the XComArg.
        for operator, key in arg.iter_references():
            if key != BaseXCom.XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
        return
    if not _is_container(arg):
        # Plain scalar values cannot hide nested XComArgs; nothing to do.
        return
    if isinstance(arg, Mapping):
        for nested in arg.values():
            ensure_xcomarg_return_value(nested)
    elif isinstance(arg, Iterable):
        for nested in arg:
            ensure_xcomarg_return_value(nested)

131 

132 

133def is_mappable_value(value: Any) -> TypeGuard[Collection]: 

134 """ 

135 Whether a value can be used for task mapping. 

136 

137 We only allow collections with guaranteed ordering, but exclude character 

138 sequences since that's usually not what users would expect to be mappable. 

139 

140 :meta private: 

141 """ 

142 if not isinstance(value, (Sequence, dict)): 

143 return False 

144 if isinstance(value, (bytearray, bytes, str)): 

145 return False 

146 return True 

147 

148 

def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
    """
    Ensure *kwargs1* and *kwargs2* do not contain common keys.

    :raises TypeError: If common keys are found.
    """
    common_keys = set(kwargs1) & set(kwargs2)
    if not common_keys:
        return
    if len(common_keys) == 1:
        raise TypeError(f"{fail_reason} argument: {common_keys.pop()}")
    # Sort for a deterministic, readable message when several keys collide.
    raise TypeError(f"{fail_reason} arguments: {', '.join(sorted(common_keys))}")

162 

163 

@attrs.define(kw_only=True, repr=False)
class OperatorPartial:
    """
    An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at Dag-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the Dag.
    """

    # The concrete operator class that will eventually be mapped.
    operator_class: type[BaseOperator]
    # Keyword arguments captured by partial(); validated in __attrs_post_init__.
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        # Reject kwargs the operator class does not accept as early as possible.
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        # Warn at garbage collection if the user built a partial but never
        # expanded it — the task would otherwise silently not exist in the Dag.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!", category=UserWarning, stacklevel=1)

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """
        Create a ``MappedOperator`` that expands over the given keyword arguments.

        Each keyword must not collide with the ones already passed to
        ``partial()``, and each value must be a mappable type (checked by
        ``validate_mapping_kwargs``).

        :raises TypeError: If no arguments are given, or a keyword is unknown
            or already specified.
        :raises ValueError: If a value is not a mappable type.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """
        Create a ``MappedOperator`` expanding over a sequence of kwargs mappings.

        *kwargs* may be a sequence of mappings/XComArgs, or a single XComArg
        resolving to such a sequence at runtime.

        :param strict: Whether duplicate keys against the partial kwargs should
            fail at execution time (passed through to the MappedOperator).
        :raises TypeError: If *kwargs* (or an element of it) has the wrong type.
        """
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # NOTE(review): this reports the *container's* type, not the
                    # offending item's (type(kwargs) vs type(item)) — looks like
                    # an upstream quirk; confirm before changing the message.
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the MappedOperator; shared implementation of expand()/expand_kwargs()."""
        # Imported here (not at module level) — presumably to avoid import
        # cycles / optional provider dependencies at parse time; confirm.
        from airflow.providers.standard.operators.empty import EmptyOperator
        from airflow.providers.standard.utils.skipmixin import SkipMixin
        from airflow.sdk import BaseSensorOperator

        self._expand_called = True
        # Any XComArg referenced in the expansion input must use the default
        # return-value key; custom keys cannot be mapped over.
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        # These are passed to MappedOperator as dedicated fields, not through
        # partial_kwargs.
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date", None)
        end_date = partial_kwargs.pop("end_date", None)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            is_sensor=issubclass(self.operator_class, BaseSensorOperator),
            can_skip_downstream=issubclass(self.operator_class, SkipMixin),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            # TODO: Move these to task SDK's BaseOperator and remove getattr
            start_trigger_args=getattr(self.operator_class, "start_trigger_args", None),
            start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)),
        )
        return op

266 

267 

@attrs.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's Dag serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through Dag pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a Dag."""

    # The class of the operator being mapped; instantiated only by unmap().
    operator_class: type[BaseOperator]

    _is_mapped: bool = attrs.field(init=False, default=True)

    # What the user passed to expand()/expand_kwargs(), wrapped as ExpandInput.
    expand_input: ExpandInput
    # Kwargs from partial(); most operator attributes below proxy into this dict.
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool = attrs.field(alias="is_empty")
    _can_skip_downstream: bool = attrs.field(alias="can_skip_downstream")
    _is_sensor: bool = attrs.field(alias="is_sensor", default=False)
    _task_module: str
    task_type: str
    _operator_name: str
    start_trigger_args: StartTriggerArgs | None
    start_from_trigger: bool
    _needs_expansion: bool = True

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
    downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
    )

    def __hash__(self):
        # Identity-based hashing; attrs equality generation is not used here.
        return id(self)

    def __repr__(self):
        return f"<Mapped({self.task_type}): {self.task_id}>"

    def __attrs_post_init__(self):
        # Wire the freshly created mapped task into its task group and Dag, and
        # register upstream relationships for every XComArg referenced either in
        # the expansion input or in templated partial kwargs.
        from airflow.sdk.definitions.xcom_arg import XComArg

        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)

    @methodtools.lru_cache(maxsize=None)
    @classmethod
    def get_serialized_fields(cls):
        """Return the set of attrs field names to serialize for a mapped task."""
        # Not using 'cls' here since we only want to serialize base fields.
        return (frozenset(attrs.fields_dict(MappedOperator))) - {
            "_is_empty",
            "_can_skip_downstream",
            "dag",
            "deps",
            "expand_input",  # This is needed to be able to accept XComArg.
            "task_group",
            "upstream_task_ids",
            "_is_setup",
            "_is_teardown",
            "_on_failure_fail_dagrun",
            "operator_class",
            "_needs_expansion",
            "partial_kwargs",
            "operator_extra_links",
        }

    @property
    def operator_name(self) -> str:
        return self._operator_name

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    # ------------------------------------------------------------------
    # The properties below proxy reads (and where sensible, writes) through
    # ``partial_kwargs`` so a mapped task exposes the same attributes a
    # regular operator would, falling back to the shared defaults when the
    # user did not pass a value to ``partial()``.
    # ------------------------------------------------------------------

    @property
    def task_display_name(self) -> str:
        return self.partial_kwargs.get("task_display_name") or self.task_id

    @property
    def owner(self) -> str:
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @owner.setter
    def owner(self, value: str) -> None:
        self.partial_kwargs["owner"] = value

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def email_on_failure(self) -> bool:
        return self.partial_kwargs.get("email_on_failure", True)

    @property
    def email_on_retry(self) -> bool:
        return self.partial_kwargs.get("email_on_retry", True)

    @property
    def map_index_template(self) -> None | str:
        return self.partial_kwargs.get("map_index_template")

    @map_index_template.setter
    def map_index_template(self, value: str | None) -> None:
        self.partial_kwargs["map_index_template"] = value

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @trigger_rule.setter
    def trigger_rule(self, value):
        self.partial_kwargs["trigger_rule"] = value

    @property
    def is_setup(self) -> bool:
        return bool(self.partial_kwargs.get("is_setup"))

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self.partial_kwargs["is_setup"] = value

    @property
    def is_teardown(self) -> bool:
        return bool(self.partial_kwargs.get("is_teardown"))

    @is_teardown.setter
    def is_teardown(self, value: bool) -> None:
        self.partial_kwargs["is_teardown"] = value

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @depends_on_past.setter
    def depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["depends_on_past"] = value

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @ignore_first_depends_on_past.setter
    def ignore_first_depends_on_past(self, value: bool) -> None:
        self.partial_kwargs["ignore_first_depends_on_past"] = value

    @property
    def wait_for_past_depends_before_skipping(self) -> bool:
        value = self.partial_kwargs.get(
            "wait_for_past_depends_before_skipping", DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
        )
        return bool(value)

    @wait_for_past_depends_before_skipping.setter
    def wait_for_past_depends_before_skipping(self, value: bool) -> None:
        self.partial_kwargs["wait_for_past_depends_before_skipping"] = value

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @wait_for_downstream.setter
    def wait_for_downstream(self, value: bool) -> None:
        self.partial_kwargs["wait_for_downstream"] = value

    @property
    def retries(self) -> int:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @retries.setter
    def retries(self, value: int) -> None:
        self.partial_kwargs["retries"] = value

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @queue.setter
    def queue(self, value: str) -> None:
        self.partial_kwargs["queue"] = value

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", DEFAULT_POOL_NAME)

    @pool.setter
    def pool(self, value: str) -> None:
        self.partial_kwargs["pool"] = value

    @property
    def pool_slots(self) -> int:
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @pool_slots.setter
    def pool_slots(self, value: int) -> None:
        self.partial_kwargs["pool_slots"] = value

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @execution_timeout.setter
    def execution_timeout(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["execution_timeout"] = value

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @max_retry_delay.setter
    def max_retry_delay(self, value: datetime.timedelta | None) -> None:
        self.partial_kwargs["max_retry_delay"] = value

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @retry_delay.setter
    def retry_delay(self, value: datetime.timedelta) -> None:
        self.partial_kwargs["retry_delay"] = value

    @property
    def retry_exponential_backoff(self) -> float:
        # Accept legacy boolean values: True maps to the multiplier 2.0 and
        # False to 0.0 (disabled); anything else is coerced to float.
        value = self.partial_kwargs.get("retry_exponential_backoff", 0)
        if value is True:
            return 2.0
        if value is False:
            return 0.0
        return float(value)

    @retry_exponential_backoff.setter
    def retry_exponential_backoff(self, value: float) -> None:
        self.partial_kwargs["retry_exponential_backoff"] = value

    @property
    def priority_weight(self) -> int:
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @priority_weight.setter
    def priority_weight(self, value: int) -> None:
        self.partial_kwargs["priority_weight"] = value

    @property
    def weight_rule(self) -> PriorityWeightStrategy:
        return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)

    @weight_rule.setter
    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
        self.partial_kwargs["weight_rule"] = value

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @max_active_tis_per_dag.setter
    def max_active_tis_per_dag(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dag"] = value

    @property
    def max_active_tis_per_dagrun(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dagrun")

    @max_active_tis_per_dagrun.setter
    def max_active_tis_per_dagrun(self, value: int | None) -> None:
        self.partial_kwargs["max_active_tis_per_dagrun"] = value

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

    # Callback accessors normalize a missing/None value to an empty list so
    # callers can iterate them unconditionally.

    @property
    def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_execute_callback") or []

    @on_execute_callback.setter
    def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_execute_callback"] = value or []

    @property
    def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_failure_callback") or []

    @on_failure_callback.setter
    def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_failure_callback"] = value or []

    @property
    def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_retry_callback") or []

    @on_retry_callback.setter
    def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_retry_callback"] = value or []

    @property
    def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_success_callback") or []

    @on_success_callback.setter
    def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_success_callback"] = value or []

    @property
    def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
        return self.partial_kwargs.get("on_skipped_callback") or []

    @on_skipped_callback.setter
    def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
        self.partial_kwargs["on_skipped_callback"] = value or []

    @property
    def has_on_execute_callback(self) -> bool:
        return bool(self.on_execute_callback)

    @property
    def has_on_failure_callback(self) -> bool:
        return bool(self.on_failure_callback)

    @property
    def has_on_retry_callback(self) -> bool:
        return bool(self.on_retry_callback)

    @property
    def has_on_success_callback(self) -> bool:
        return bool(self.on_success_callback)

    @property
    def has_on_skipped_callback(self) -> bool:
        return bool(self.on_skipped_callback)

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor(self) -> str | None:
        return self.partial_kwargs.get("executor", DEFAULT_EXECUTOR)

    @property
    def executor_config(self) -> dict:
        return self.partial_kwargs.get("executor_config", {})

    @property
    def inlets(self) -> list[Any]:
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:
        self.partial_kwargs["inlets"] = value

    @property
    def outlets(self) -> list[Any]:
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:
        self.partial_kwargs["outlets"] = value

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    @property
    def allow_nested_operators(self) -> bool:
        return bool(self.partial_kwargs.get("allow_nested_operators"))

    def get_dag(self) -> DAG | None:
        """Implement Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return reference to XCom pushed by current operator."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implement DAGNode."""
        return DagAttributeTypes.OP, self.task_id

    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.
        """
        return self._get_specified_expand_input().resolve(context)

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """
        Get init kwargs to unmap the underlying operator class.

        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
        :param strict: If true, raise instead of letting mapped kwargs silently
            override keys already given to ``partial()``.
        """
        if strict:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # If params appears in the mapped kwargs, we need to merge it into the
        # partial params, overriding existing keys.
        params = copy.copy(self.params)
        with contextlib.suppress(KeyError):
            params.update(mapped_kwargs["params"])

        # Ordering is significant; mapped kwargs should override partial ones,
        # and the specially handled params should be respected.
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **mapped_kwargs,
            "params": params,
        }

    def unmap(self, resolve: Mapping[str, Any]) -> BaseOperator:
        """
        Get the "normal" Operator after applying the current mapping.

        :meta private:
        """
        kwargs = self._get_unmap_kwargs(resolve, strict=self._disallow_kwargs_override)
        # These are popped and assigned after construction — presumably because
        # the operator constructor does not accept them as arguments; confirm.
        is_setup = kwargs.pop("is_setup", False)
        is_teardown = kwargs.pop("is_teardown", False)
        on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
        kwargs["task_id"] = self.task_id
        op = self.operator_class(**kwargs, _airflow_from_mapped=True)
        op.is_setup = is_setup
        op.is_teardown = is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        # Mirror this mapped task's dependency edges onto the unmapped copy.
        op.downstream_task_ids = self.downstream_task_ids
        op.upstream_task_ids = self.upstream_task_ids
        return op

    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    # TODO (GH-52141): Do we need this in the SDK?
    def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """
        Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        from airflow.sdk.execution_time.context import context_update_for_unmapped

        if not jinja_env:
            jinja_env = self.get_template_env()

        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        # Since the operators that extend `BaseOperator` are not subclasses of
        # `MappedOperator`, we need to call `_do_render_template_fields` from
        # the unmapped task in order to call the operator method when we override
        # it to customize the parsing of nested fields.
        unmapped_task._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
        )

    def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
        """
        Get the kwargs to create the unmapped start_trigger_args.

        This method is for allowing mapped operator to start execution from triggerer.
        """
        from airflow.triggers.base import StartTriggerArgs

        if not self.start_trigger_args:
            return None

        mapped_kwargs, _ = self._expand_mapped_kwargs(context)
        if self._disallow_kwargs_override:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # Ordering is significant; mapped kwargs should override partial ones.
        trigger_kwargs = mapped_kwargs.get(
            "trigger_kwargs",
            self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
        )
        next_kwargs = mapped_kwargs.get(
            "next_kwargs",
            self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs),
        )
        timeout = mapped_kwargs.get(
            "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout)
        )
        return StartTriggerArgs(
            trigger_cls=self.start_trigger_args.trigger_cls,
            trigger_kwargs=trigger_kwargs,
            next_method=self.start_trigger_args.next_method,
            next_kwargs=next_kwargs,
            timeout=timeout,
        )