Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/bases/decorator.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

332 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import inspect 

20import itertools 

21import re 

22import textwrap 

23import warnings 

24from collections.abc import Callable, Collection, Iterator, Mapping, Sequence 

25from contextlib import suppress 

26from functools import cached_property, partial, update_wrapper 

27from typing import TYPE_CHECKING, Any, ClassVar, Generic, ParamSpec, Protocol, TypeVar, cast, overload 

28 

29import attr 

30import typing_extensions 

31 

32from airflow.sdk import TriggerRule, timezone 

33from airflow.sdk.bases.operator import ( 

34 BaseOperator, 

35 coerce_resources, 

36 coerce_timedelta, 

37 get_merged_defaults, 

38 parse_retries, 

39) 

40from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext 

41from airflow.sdk.definitions._internal.decorators import remove_task_decorator 

42from airflow.sdk.definitions._internal.expandinput import ( 

43 EXPAND_INPUT_EMPTY, 

44 DictOfListsExpandInput, 

45 ListOfDictsExpandInput, 

46 is_mappable, 

47) 

48from airflow.sdk.definitions._internal.types import NOTSET 

49from airflow.sdk.definitions.asset import Asset 

50from airflow.sdk.definitions.context import KNOWN_CONTEXT_KEYS 

51from airflow.sdk.definitions.mappedoperator import ( 

52 MappedOperator, 

53 ensure_xcomarg_return_value, 

54 prevent_duplicates, 

55) 

56from airflow.sdk.definitions.xcom_arg import XComArg 

57 

58if TYPE_CHECKING: 

59 from airflow.sdk.definitions._internal.expandinput import ( 

60 ExpandInput, 

61 OperatorExpandArgument, 

62 OperatorExpandKwargsArgument, 

63 ) 

64 from airflow.sdk.definitions.context import Context 

65 from airflow.sdk.definitions.dag import DAG 

66 from airflow.sdk.definitions.mappedoperator import ValidationSource 

67 from airflow.sdk.definitions.taskgroup import TaskGroup 

68 

69 

class ExpandableFactory(Protocol):
    """
    Protocol providing inspection against wrapped function.

    This is used in ``validate_expand_kwargs`` and implemented by function
    decorators like ``@task`` and ``@task_group``.

    :meta private:
    """

    function: Callable

    @cached_property
    def function_signature(self) -> inspect.Signature:
        """Signature of the wrapped callable, computed once per instance."""
        return inspect.signature(self.function)

    @cached_property
    def _mappable_function_argument_names(self) -> set[str]:
        """Arguments that can be mapped against."""
        return set(self.function_signature.parameters)

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> None:
        """Ensure that all arguments passed to operator-mapping functions are accounted for."""
        params = self.function_signature.parameters
        # A **kwargs catch-all accepts any keyword, so there is nothing to check.
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return
        unknown = dict(kwargs)
        for name in self._mappable_function_argument_names:
            candidate = unknown.pop(name, NOTSET)
            # expand() values must be mappable containers; partial() values may be anything.
            if func == "expand" and candidate is not NOTSET and not is_mappable(candidate):
                tname = type(candidate).__name__
                raise ValueError(
                    f"expand() got an unexpected type {tname!r} for keyword argument {name!r}"
                )
        if len(unknown) == 1:
            raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(unknown))!r}")
        if unknown:
            names = ", ".join(repr(n) for n in unknown)
            raise TypeError(f"{func}() got unexpected keyword arguments {names}")

110 

def get_unique_task_id(
    task_id: str,
    dag: DAG | None = None,
    task_group: TaskGroup | None = None,
) -> str:
    """
    Generate a task id that is unique within the given (or current) Dag.

    Uniqueness is achieved by appending an incrementing numeric suffix to
    the requested id:

    Example:
        task_id
        task_id__1
        task_id__2
        ...
        task_id__20
    """
    dag = dag or DagContext.get_current()
    if not dag:
        return task_id

    # The task may live inside a TaskGroup, in which case the id registered
    # on the dag is the group-qualified one; compare against that form.
    task_group = task_group or TaskGroupContext.get_current(dag)
    qualified = task_group.child_id(task_id) if task_group else task_id

    if qualified not in dag.task_ids:
        return task_id

    # Find the highest existing "__N" suffix for this id (0 when none match),
    # then append the next number to the suffix-stripped requested id.
    prefix = re.split(r"__\d+$", qualified)[0]
    suffix_re = re.compile(rf"^{prefix}__(\d+)$")
    highest = max(
        (int(m.group(1)) for m in map(suffix_re.match, dag.task_ids) if m),
        default=0,
    )
    stem = re.split(r"__\d+$", task_id)[0]
    return f"{stem}__{highest + 1}"

151 

152 

def unwrap_partial(fn: Callable) -> Callable:
    """Peel away any layers of ``functools.partial`` and return the innermost callable."""
    return unwrap_partial(fn.func) if isinstance(fn, partial) else fn

157 

158 

def unwrap_callable(func):
    """Strip decorator, partial, and ``functools.wraps`` layers to reach the underlying callable."""
    from airflow.sdk.definitions.mappedoperator import OperatorPartial

    # Task decorators / operator partials carry the real function on an attribute.
    if isinstance(func, (_TaskDecorator, OperatorPartial)):
        fallback = getattr(func, "_func", func)
        func = getattr(func, "function", fallback)

    func = unwrap_partial(func)

    # inspect.unwrap can raise on exotic wrappers; best-effort only.
    try:
        func = inspect.unwrap(func)
    except Exception:
        pass

    return func

171 

172 

def is_async_callable(func):
    """Detect if a callable (possibly wrapped) is an async function."""
    func = unwrap_callable(func)

    if not callable(func):
        return False

    # Plain async def (or wrapped async def already unwrapped above).
    if inspect.iscoroutinefunction(func):
        return True

    # Ordinary sync function: nothing more to inspect.
    if inspect.isfunction(func):
        return False

    # Callable object: look at the class-level __call__ (Bandit-safe),
    # unwrapping decorators around it best-effort.
    call = type(func).__call__
    with suppress(Exception):
        call = inspect.unwrap(call)
    return inspect.iscoroutinefunction(call)

193 

194 

class DecoratedOperator(BaseOperator):
    """
    Wraps a Python callable and captures args/kwargs when called for execution.

    :param python_callable: A reference to an object that is callable
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function (templated)
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable (templated)
    :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
        multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
    :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments
        that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the
        PythonOperator). This gives a user the option to upstream kwargs as needed.
    """

    template_fields: Sequence[str] = ("op_args", "op_kwargs")
    template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

    # since we won't mutate the arguments, we should just do the shallow copy
    # there are some cases we can't deepcopy the objects (e.g protobuf).
    shallow_copy_attrs: Sequence[str] = ("python_callable",)

    def __init__(
        self,
        *,
        python_callable: Callable,
        task_id: str,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        kwargs_to_upstream: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        if not getattr(self, "_BaseOperator__from_mapped", False):
            # If we are being created from calling unmap(), then don't mangle the task id
            task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
        self.python_callable = python_callable
        kwargs_to_upstream = kwargs_to_upstream or {}
        op_args = op_args or []
        op_kwargs = op_kwargs or {}

        # Check the decorated function's signature. We go through the argument
        # list and "fill in" defaults to arguments that are known context keys,
        # since values for those will be provided when the task is run. Since
        # we're not actually running the function, None is good enough here.
        signature = inspect.signature(python_callable)

        # Don't allow context argument defaults other than None to avoid ambiguities.
        faulty_parameters = [
            param.name
            for param in signature.parameters.values()
            if param.name in KNOWN_CONTEXT_KEYS and param.default not in (None, inspect.Parameter.empty)
        ]
        if faulty_parameters:
            message = f"Context key parameter {faulty_parameters[0]} can't have a default other than None"
            raise ValueError(message)

        parameters = [
            param.replace(default=None) if param.name in KNOWN_CONTEXT_KEYS else param
            for param in signature.parameters.values()
        ]
        try:
            signature = signature.replace(parameters=parameters)
        except ValueError as err:
            # Replacing defaults can break signature ordering rules (e.g. a
            # defaulted parameter before a non-defaulted one); surface a
            # readable before/after diff to the user.
            message = textwrap.dedent(
                f"""
                The function signature broke while assigning defaults to context key parameters.

                The decorator is replacing the signature
                > {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())})

                with
                > {python_callable.__name__}({", ".join(str(param) for param in parameters)})

                which isn't valid: {err}
                """
            )
            raise ValueError(message) from err

        # Check that arguments can be bound. There's a slight difference when
        # we do validation for task-mapping: Since there's no guarantee we can
        # receive enough arguments at parse time, we use bind_partial to simply
        # check all the arguments we know are valid. Whether these are enough
        # can only be known at execution time, when unmapping happens, and this
        # is called without the _airflow_mapped_validation_only flag.
        if kwargs.get("_airflow_mapped_validation_only"):
            signature.bind_partial(*op_args, **op_kwargs)
        else:
            signature.bind(*op_args, **op_kwargs)

        self.op_args = op_args
        self.op_kwargs = op_kwargs
        super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)

    @property
    def is_async(self) -> bool:
        """Whether the wrapped callable is an async (coroutine) function."""
        return is_async_callable(self.python_callable)

    def execute(self, context: Context):
        """Run the wrapped callable, registering Asset arguments as inlets first."""
        # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
        # as well
        for arg in itertools.chain(self.op_args, self.op_kwargs.values()):
            if isinstance(arg, Asset):
                self.inlets.append(arg)
        return_value = super().execute(context)
        return self._handle_output(return_value=return_value)

    def _handle_output(self, return_value: Any):
        """
        Handle logic for whether a decorator needs to push a single return value or multiple return values.

        It sets outlets if any assets are found in the returned value(s).

        :param return_value: the value returned by the wrapped callable
        """
        if isinstance(return_value, Asset):
            self.outlets.append(return_value)
        if isinstance(return_value, list):
            for item in return_value:
                if isinstance(item, Asset):
                    self.outlets.append(item)
        return return_value

    def _hook_apply_defaults(self, *args, **kwargs):
        """Backfill op_kwargs from default_args for parameters the callable accepts."""
        if "python_callable" not in kwargs:
            return args, kwargs

        python_callable = kwargs["python_callable"]
        default_args = kwargs.get("default_args") or {}
        op_kwargs = kwargs.get("op_kwargs") or {}
        f_sig = inspect.signature(python_callable)
        # Only fill parameters that the function declares and the caller did not pass.
        for arg in f_sig.parameters:
            if arg not in op_kwargs and arg in default_args:
                op_kwargs[arg] = default_args[arg]
        kwargs["op_kwargs"] = op_kwargs
        return args, kwargs

    def get_python_source(self):
        """Return the dedented source of the wrapped callable with the @task decorator removed."""
        raw_source = inspect.getsource(self.python_callable)
        # Drop comment-only lines before dedenting/stripping the decorator.
        raw_source_lines = [line for line in raw_source.splitlines() if not line.strip().startswith("#")]
        res = textwrap.dedent("\n".join(raw_source_lines)) + "\n"
        res = remove_task_decorator(res, self.custom_operator_name)
        return res

340 

341 

# Generic parameters carrying the decorated function's call signature and
# return type through _TaskDecorator / Task / TaskDecorator below.
FParams = ParamSpec("FParams")

FReturn = TypeVar("FReturn")

# Concrete operator class produced for the decorated function.
OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")

347 

348 

@attr.define(slots=False)
class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]):
    """
    Helper class for providing dynamic task mapping to decorated functions.

    ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function.

    :meta private:
    """

    # The wrapped user function; must be callable (enforced by the validator).
    function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
    operator_class: type[OperatorSubclass]
    # Defaults via _infer_multiple_outputs below when not passed explicitly.
    multiple_outputs: bool = attr.ib()
    # Extra keyword arguments forwarded to the operator constructor.
    kwargs: dict[str, Any] = attr.ib(factory=dict)

    decorator_name: str = attr.ib(repr=False, default="task")

    _airflow_is_task_decorator: ClassVar[bool] = True
    is_setup: bool = False
    is_teardown: bool = False
    on_failure_fail_dagrun: bool = False

    # This is set in __attrs_post_init__ by update_wrapper. Provided here for type hints.
    __wrapped__: Callable[FParams, FReturn] = attr.ib(init=False)

    @multiple_outputs.default
    def _infer_multiple_outputs(self):
        # Infer multiple_outputs=True when the function's return annotation
        # resolves to a Mapping type; any failure to resolve means False.
        if "return" not in self.function.__annotations__:
            # No return type annotation, nothing to infer
            return False

        try:
            # We only care about the return annotation, not anything about the parameters
            def fake(): ...

            fake.__annotations__ = {"return": self.function.__annotations__["return"]}

            return_type = typing_extensions.get_type_hints(fake, self.function.__globals__).get("return", Any)
        except NameError as e:
            warnings.warn(
                f"Cannot infer multiple_outputs for TaskFlow function {self.function.__name__!r} with forward"
                f" type references that are not imported. (Error was {e})",
                stacklevel=4,
            )
            return False
        except TypeError:  # Can't evaluate return type.
            return False
        # Unwrap generics (e.g. dict[str, int] -> dict) before the subclass check.
        ttype = getattr(return_type, "__origin__", return_type)
        return isinstance(ttype, type) and issubclass(ttype, Mapping)

    def __attrs_post_init__(self):
        if "self" in self.function_signature.parameters:
            raise TypeError(f"@{self.decorator_name} does not support methods")
        self.kwargs.setdefault("task_id", self.function.__name__)
        update_wrapper(self, self.function)

    def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg:
        """Instantiate the operator for this task and return its XComArg result handle."""
        if self.is_teardown:
            if "trigger_rule" in self.kwargs:
                raise ValueError("Trigger rule not configurable for teardown tasks.")
            self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
        on_failure_fail_dagrun = self.kwargs.pop("on_failure_fail_dagrun", self.on_failure_fail_dagrun)
        op = self.operator_class(
            python_callable=self.function,
            op_args=args,
            op_kwargs=kwargs,
            multiple_outputs=self.multiple_outputs,
            **self.kwargs,
        )
        op.is_setup = self.is_setup
        op.is_teardown = self.is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, op.doc_yaml]
        # Set the task's doc_md to the function's docstring if it exists and no other doc* args are set.
        if self.function.__doc__ and not any(op_doc_attrs):
            op.doc_md = self.function.__doc__
        return XComArg(op)

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]):
        # Ensure that context variables are not shadowed.
        context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs)
        if len(context_keys_being_mapped) == 1:
            (name,) = context_keys_being_mapped
            raise ValueError(f"cannot call {func}() on task context variable {name!r}")
        if context_keys_being_mapped:
            names = ", ".join(repr(n) for n in context_keys_being_mapped)
            raise ValueError(f"cannot call {func}() on task context variables {names}")

        super()._validate_arg_names(func, kwargs)

    def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
        """Create a mapped task expanding over each keyword's list of values."""
        if self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS and any(
            [isinstance(expanded, XComArg) for expanded in map_kwargs.values()]
        ):
            raise ValueError(
                "Task-generated mapping within a task using 'expand' is not allowed with trigger rule 'always'."
            )
        if not map_kwargs:
            raise TypeError("no arguments to expand against")
        self._validate_arg_names("expand", map_kwargs)
        prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        if self.is_teardown:
            if "trigger_rule" in self.kwargs:
                raise ValueError("Trigger rule not configurable for teardown tasks.")
            self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
        return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
        """Create a mapped task expanding over a list of kwargs dicts (or an XComArg of them)."""
        if (
            self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS
            and not isinstance(kwargs, XComArg)
            and any(
                [
                    isinstance(v, XComArg)
                    for kwarg in kwargs
                    if not isinstance(kwarg, XComArg)
                    for v in kwarg.values()
                ]
            )
        ):
            raise ValueError(
                "Task-generated mapping within a task using 'expand_kwargs' is not allowed with trigger rule 'always'."
            )
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
        """Build the DecoratedMappedOperator for this task and return its XComArg."""
        ensure_xcomarg_return_value(expand_input.value)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current()
        task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag)

        default_args, partial_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs: dict[str, Any] = {
            "is_setup": self.is_setup,
            "is_teardown": self.is_teardown,
            "on_failure_fail_dagrun": self.on_failure_fail_dagrun,
        }
        base_signature = inspect.signature(BaseOperator)
        ignore = {
            "default_args",  # This is target we are working on now.
            "kwargs",  # A common name for a keyword argument.
            "do_xcom_push",  # In the same boat as `multiple_outputs`
            "multiple_outputs",  # We will use `self.multiple_outputs` instead.
            "params",  # Already handled above `partial_params`.
            "task_concurrency",  # Deprecated(replaced by `max_active_tis_per_dag`).
        }
        partial_keys = set(base_signature.parameters) - ignore
        partial_kwargs.update({key: value for key, value in default_args.items() if key in partial_keys})
        partial_kwargs.update(task_kwargs)

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
        if task_group:
            task_id = task_group.child_id(task_id)

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
        if "pool_slots" in partial_kwargs:
            if partial_kwargs["pool_slots"] < 1:
                dag_str = ""
                if dag:
                    dag_str = f" in dag {dag.dag_id}"
                raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")

        # Normalize user-supplied values into their canonical types.
        for fld, convert in (
            ("retries", parse_retries),
            ("retry_delay", coerce_timedelta),
            ("max_retry_delay", coerce_timedelta),
            ("resources", coerce_resources),
        ):
            if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET:
                partial_kwargs[fld] = convert(v)

        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast("Any", DecoratedMappedOperator)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        operator = _MappedOperator(
            operator_class=self.operator_class,
            expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=partial_params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            is_sensor=self.operator_class._is_sensor,
            can_skip_downstream=self.operator_class._can_skip_downstream,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            op_kwargs_expand_input=expand_input,
            disallow_kwargs_override=strict,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expand_input_attr="op_kwargs_expand_input",
            start_trigger_args=self.operator_class.start_trigger_args,
            start_from_trigger=self.operator_class.start_from_trigger,
        )
        return XComArg(operator=operator)

    def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        """Return a copy of this decorator with the given op_kwargs pre-bound."""
        self._validate_arg_names("partial", kwargs)
        old_kwargs = self.kwargs.get("op_kwargs", {})
        prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
        # Previously-bound kwargs take precedence over new ones.
        kwargs.update(old_kwargs)
        return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})

    def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        """Return a copy of this decorator with operator kwargs overridden."""
        result = attr.evolve(self, kwargs={**self.kwargs, **kwargs})
        # attr.evolve does not carry over these non-init flags; copy them explicitly.
        setattr(result, "is_setup", self.is_setup)
        setattr(result, "is_teardown", self.is_teardown)
        setattr(result, "on_failure_fail_dagrun", self.on_failure_fail_dagrun)
        return result

598 

599 

@attr.define(kw_only=True, repr=False)
class DecoratedMappedOperator(MappedOperator):
    """MappedOperator implementation for @task-decorated task function."""

    multiple_outputs: bool
    python_callable: Callable

    # We can't save these in expand_input because op_kwargs need to be present
    # in partial_kwargs, and MappedOperator prevents duplication.
    op_kwargs_expand_input: ExpandInput

    def __hash__(self):
        # Hash by object identity.
        return id(self)

    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        # We only use op_kwargs_expand_input so this must always be empty.
        if self.expand_input is not EXPAND_INPUT_EMPTY:
            raise AssertionError(f"unexpected expand_input: {self.expand_input}")
        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context)
        # Resolved mapped values become the op_kwargs argument of the operator.
        return {"op_kwargs": op_kwargs}, resolved_oids

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """Merge partial and mapped op_kwargs into the final operator constructor kwargs."""
        partial_op_kwargs = self.partial_kwargs["op_kwargs"]
        mapped_op_kwargs = mapped_kwargs["op_kwargs"]

        if strict:
            prevent_duplicates(partial_op_kwargs, mapped_op_kwargs, fail_reason="mapping already partial")

        # Mapped values take precedence over partially-bound ones.
        kwargs = {
            "multiple_outputs": self.multiple_outputs,
            "python_callable": self.python_callable,
            "op_kwargs": {**partial_op_kwargs, **mapped_op_kwargs},
        }
        # Duplicates were already checked above, so skip the strict check upstream.
        return super()._get_unmap_kwargs(kwargs, strict=False)

634 

635 

class Task(Protocol, Generic[FParams, FReturn]):
    """
    Declaration of a @task-decorated callable for type-checking.

    An instance of this type inherits the call signature of the decorated
    function wrapped in it (not *exactly* since it actually returns an XComArg,
    but there's no way to express that right now), and provides two additional
    methods for task-mapping.

    This type is implemented by ``_TaskDecorator`` at runtime.
    """

    __call__: Callable[FParams, XComArg]

    # The original (undecorated) user function.
    function: Callable[FParams, FReturn]

    @property
    def __wrapped__(self) -> Callable[FParams, FReturn]: ...

    def partial(self, **kwargs: Any) -> Task[FParams, FReturn]: ...

    def expand(self, **kwargs: OperatorExpandArgument) -> XComArg: ...

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: ...

    def override(self, **kwargs: Any) -> Task[FParams, FReturn]: ...

663 

class TaskDecorator(Protocol):
    """Type declaration for ``task_decorator_factory`` return type."""

    @overload
    def __call__(  # type: ignore[misc]
        self,
        python_callable: Callable[FParams, FReturn],
    ) -> Task[FParams, FReturn]:
        """For the "bare decorator" ``@task`` case."""

    @overload
    def __call__(
        self,
        *,
        multiple_outputs: bool | None = None,
        **kwargs: Any,
    ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
        """For the decorator factory ``@task()`` case."""

    def override(self, **kwargs: Any) -> Task[FParams, FReturn]: ...

685 

def task_decorator_factory(
    python_callable: Callable | None = None,
    *,
    multiple_outputs: bool | None = None,
    decorated_operator_class: type[BaseOperator],
    **kwargs,
) -> TaskDecorator:
    """
    Generate a wrapper that wraps a function into an Airflow operator.

    Can be reused in a single Dag.

    :param python_callable: Function to decorate.
    :param multiple_outputs: If set to True, the decorated function's return
        value will be unrolled to multiple XCom values. Dict will unroll to XCom
        values with its keys as XCom keys. If set to False (default), only at
        most one XCom value is pushed.
    :param decorated_operator_class: The operator that executes the logic needed
        to run the python function in the correct environment.

    Other kwargs are directly forwarded to the underlying operator class when
    it's instantiated.
    """
    if multiple_outputs is None:
        # attr.NOTHING makes _TaskDecorator fall back to its inferred default.
        multiple_outputs = cast("bool", attr.NOTHING)

    def make_decorator(fn: Callable) -> _TaskDecorator:
        # Shared constructor for both the bare-decorator and factory cases.
        return _TaskDecorator(
            function=fn,
            multiple_outputs=multiple_outputs,
            operator_class=decorated_operator_class,
            kwargs=kwargs,
        )

    # Bare use: @task directly on the function.
    if python_callable:
        return cast("TaskDecorator", make_decorator(python_callable))
    # A non-callable positional argument is a user error.
    if python_callable is not None:
        raise TypeError("No args allowed while using @task, use kwargs instead")
    # Factory use: @task(...) returning the actual decorator.
    return cast("TaskDecorator", make_decorator)