Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/bases/operator.py: 36%


678 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import abc 

21import collections.abc 

22import contextlib 

23import copy 

24import inspect 

25import sys 

26import warnings 

27from collections.abc import Callable, Collection, Iterable, Mapping, Sequence 

28from contextvars import ContextVar 

29from dataclasses import dataclass, field 

30from datetime import datetime, timedelta 

31from enum import Enum 

32from functools import total_ordering, wraps 

33from types import FunctionType 

34from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast 

35 

36import attrs 

37 

38from airflow.sdk import TriggerRule, timezone 

39from airflow.sdk._shared.secrets_masker import redact 

40from airflow.sdk.definitions._internal.abstractoperator import ( 

41 DEFAULT_EMAIL_ON_FAILURE, 

42 DEFAULT_EMAIL_ON_RETRY, 

43 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

44 DEFAULT_OWNER, 

45 DEFAULT_POOL_NAME, 

46 DEFAULT_POOL_SLOTS, 

47 DEFAULT_PRIORITY_WEIGHT, 

48 DEFAULT_QUEUE, 

49 DEFAULT_RETRIES, 

50 DEFAULT_RETRY_DELAY, 

51 DEFAULT_TASK_EXECUTION_TIMEOUT, 

52 DEFAULT_TRIGGER_RULE, 

53 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

54 DEFAULT_WEIGHT_RULE, 

55 AbstractOperator, 

56 DependencyMixin, 

57 TaskStateChangeCallback, 

58) 

59from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack 

60from airflow.sdk.definitions._internal.node import validate_key 

61from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext 

62from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args 

63from airflow.sdk.definitions.edges import EdgeModifier 

64from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs 

65from airflow.sdk.definitions.param import ParamsDict 

66from airflow.sdk.exceptions import RemovedInAirflow4Warning 

67 

68# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights. 

69# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html) 

70# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html) 

71# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html) 

72DB_SAFE_MINIMUM = -2147483648 

73DB_SAFE_MAXIMUM = 2147483647 

74 

75 

76def db_safe_priority(priority_weight: int) -> int: 

77 """Convert priority weight to a safe value for the database.""" 

78 return max(DB_SAFE_MINIMUM, min(DB_SAFE_MAXIMUM, priority_weight)) 
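
# A minimal sketch of the clamping behaviour above (illustrative, not part of the module):
#   db_safe_priority(10)              -> 10
#   db_safe_priority(9_999_999_999)   -> 2147483647   (DB_SAFE_MAXIMUM)
#   db_safe_priority(-9_999_999_999)  -> -2147483648  (DB_SAFE_MINIMUM)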

79 

80 

81C = TypeVar("C", bound=Callable) 

82T = TypeVar("T", bound=FunctionType) 

83 

84if TYPE_CHECKING: 

85 from types import ClassMethodDescriptorType 

86 

87 import jinja2 

88 from typing_extensions import Self 

89 

90 from airflow.sdk.bases.operatorlink import BaseOperatorLink 

91 from airflow.sdk.definitions.context import Context 

92 from airflow.sdk.definitions.dag import DAG 

93 from airflow.sdk.definitions.operator_resources import Resources 

94 from airflow.sdk.definitions.taskgroup import TaskGroup 

95 from airflow.sdk.definitions.xcom_arg import XComArg 

96 from airflow.serialization.enums import DagAttributeTypes 

97 from airflow.task.priority_strategy import PriorityWeightStrategy 

98 from airflow.triggers.base import BaseTrigger, StartTriggerArgs 

99 

100 TaskPreExecuteHook = Callable[[Context], None] 

101 TaskPostExecuteHook = Callable[[Context, Any], None] 

102 

103__all__ = [ 

104 "BaseOperator", 

105 "chain", 

106 "chain_linear", 

107 "cross_downstream", 

108] 

109 

110 

111class TriggerFailureReason(str, Enum): 

112 """ 

113 Reasons for trigger failures. 

114 

115 Internal use only. 

116 

117 :meta private: 

118 """ 

119 

120 TRIGGER_TIMEOUT = "Trigger timeout" 

121 TRIGGER_FAILURE = "Trigger failure" 

122 

123 

124TRIGGER_FAIL_REPR = "__fail__" 

125"""String value to represent trigger failure. 

126 

127Internal use only. 

128 

129:meta private: 

130""" 

131 

132 

133def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]: 

134 if not dag: 

135 return {}, ParamsDict() 

136 dag_args = copy.copy(dag.default_args) 

137 dag_params = copy.deepcopy(dag.params) 

138 dag_params._fill_missing_param_source("dag") 

139 if task_group: 

140 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping): 

141 raise TypeError("default_args must be a mapping") 

142 dag_args.update(task_group.default_args) 

143 return dag_args, dag_params 

144 

145 

146def get_merged_defaults( 

147 dag: DAG | None, 

148 task_group: TaskGroup | None, 

149 task_params: collections.abc.MutableMapping | None, 

150 task_default_args: dict | None, 

151) -> tuple[dict, ParamsDict]: 

152 args, params = _get_parent_defaults(dag, task_group) 

153 if task_params: 

154 if not isinstance(task_params, collections.abc.Mapping): 

155 raise TypeError(f"params must be a mapping, got {type(task_params)}") 

156 

157 task_params = ParamsDict(task_params) 

158 task_params._fill_missing_param_source("task") 

159 params.update(task_params) 

160 

161 if task_default_args: 

162 if not isinstance(task_default_args, collections.abc.Mapping): 

163 raise TypeError(f"default_args must be a mapping, got {type(task_params)}") 

164 args.update(task_default_args) 

165 with contextlib.suppress(KeyError): 

166 if params_from_default_args := ParamsDict(task_default_args["params"] or {}): 

167 params_from_default_args._fill_missing_param_source("task") 

168 params.update(params_from_default_args) 

169 

170 return args, params 
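
# Illustrative sketch of the merge order (assumed names, not part of the module):
# task-level values override task-group and dag-level defaults, and params from all
# layers are merged into one ParamsDict.
#   dag = DAG("example_dag", default_args={"retries": 1, "owner": "data-eng"})
#   args, params = get_merged_defaults(
#       dag=dag,
#       task_group=None,
#       task_params={"limit": 10},
#       task_default_args={"retries": 3},
#   )
#   # args == {"retries": 3, "owner": "data-eng"}; params now contains "limit"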

171 

172 

173def parse_retries(retries: Any) -> int | None: 

174 if retries is None: 

175 return 0 

176 if type(retries) == int: # noqa: E721 

177 return retries 

178 try: 

179 parsed_retries = int(retries) 

180 except (TypeError, ValueError): 

181 raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}") 

182 return parsed_retries 
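
# Rough behaviour sketch (not part of the module):
#   parse_retries(None)   -> 0
#   parse_retries("3")    -> 3
#   parse_retries("many") -> RuntimeError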

183 

184 

185def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta: 

186 if isinstance(value, timedelta): 

187 return value 

188 return timedelta(seconds=value) 

189 

190 

191def coerce_resources(resources: dict[str, Any] | None) -> Resources | None: 

192 if resources is None: 

193 return None 

194 from airflow.sdk.definitions.operator_resources import Resources 

195 

196 return Resources(**resources) 

197 

198 

199class _PartialDescriptor: 

200 """A descriptor that guards against ``.partial`` being called on Task objects.""" 

201 

202 class_method: ClassMethodDescriptorType | None = None 

203 

204 def __get__( 

205 self, obj: BaseOperator, cls: type[BaseOperator] | None = None 

206 ) -> Callable[..., OperatorPartial]: 

207 # Call this "partial" so it looks nicer in stack traces. 

208 def partial(**kwargs): 

209 raise TypeError("partial can only be called on Operator classes, not Tasks themselves") 

210 

211 if obj is not None: 

212 return partial 

213 return self.class_method.__get__(cls, cls) 
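
# Illustrative sketch (assumed operator class, not part of the module): the descriptor
# exposes ``partial`` only on the class, never on an instantiated task.
#   BashOperator.partial(task_id="t")                         # -> OperatorPartial
#   BashOperator(task_id="t", bash_command="true").partial()  # -> TypeError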

214 

215 

216OPERATOR_DEFAULTS: dict[str, Any] = { 

217 "allow_nested_operators": True, 

218 "depends_on_past": False, 

219 "email_on_failure": DEFAULT_EMAIL_ON_FAILURE, 

220 "email_on_retry": DEFAULT_EMAIL_ON_RETRY, 

221 "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, 

222 # "executor": DEFAULT_EXECUTOR, 

223 "executor_config": {}, 

224 "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

225 "inlets": [], 

226 "map_index_template": None, 

227 "on_execute_callback": [], 

228 "on_failure_callback": [], 

229 "on_retry_callback": [], 

230 "on_skipped_callback": [], 

231 "on_success_callback": [], 

232 "outlets": [], 

233 "owner": DEFAULT_OWNER, 

234 "pool_slots": DEFAULT_POOL_SLOTS, 

235 "priority_weight": DEFAULT_PRIORITY_WEIGHT, 

236 "queue": DEFAULT_QUEUE, 

237 "retries": DEFAULT_RETRIES, 

238 "retry_delay": DEFAULT_RETRY_DELAY, 

239 "retry_exponential_backoff": 0, 

240 "trigger_rule": DEFAULT_TRIGGER_RULE, 

241 "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

242 "wait_for_downstream": False, 

243 "weight_rule": DEFAULT_WEIGHT_RULE, 

244} 

245 

246 

247# This is what handles the actual mapping. 

248 

249if TYPE_CHECKING: 

250 

251 def partial( 

252 operator_class: type[BaseOperator], 

253 *, 

254 task_id: str, 

255 dag: DAG | None = None, 

256 task_group: TaskGroup | None = None, 

257 start_date: datetime = ..., 

258 end_date: datetime = ..., 

259 owner: str = ..., 

260 email: None | str | Iterable[str] = ..., 

261 params: collections.abc.MutableMapping | None = None, 

262 resources: dict[str, Any] | None = ..., 

263 trigger_rule: str = ..., 

264 depends_on_past: bool = ..., 

265 ignore_first_depends_on_past: bool = ..., 

266 wait_for_past_depends_before_skipping: bool = ..., 

267 wait_for_downstream: bool = ..., 

268 retries: int | None = ..., 

269 queue: str = ..., 

270 pool: str = ..., 

271 pool_slots: int = ..., 

272 execution_timeout: timedelta | None = ..., 

273 max_retry_delay: None | timedelta | float = ..., 

274 retry_delay: timedelta | float = ..., 

275 retry_exponential_backoff: float = ..., 

276 priority_weight: int = ..., 

277 weight_rule: str | PriorityWeightStrategy = ..., 

278 sla: timedelta | None = ..., 

279 map_index_template: str | None = ..., 

280 max_active_tis_per_dag: int | None = ..., 

281 max_active_tis_per_dagrun: int | None = ..., 

282 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

283 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

284 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

285 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

286 on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

287 run_as_user: str | None = ..., 

288 executor: str | None = ..., 

289 executor_config: dict | None = ..., 

290 inlets: Any | None = ..., 

291 outlets: Any | None = ..., 

292 doc: str | None = ..., 

293 doc_md: str | None = ..., 

294 doc_json: str | None = ..., 

295 doc_yaml: str | None = ..., 

296 doc_rst: str | None = ..., 

297 task_display_name: str | None = ..., 

298 logger_name: str | None = ..., 

299 allow_nested_operators: bool = True, 

300 **kwargs, 

301 ) -> OperatorPartial: ... 

302else: 

303 

304 def partial( 

305 operator_class: type[BaseOperator], 

306 *, 

307 task_id: str, 

308 dag: DAG | None = None, 

309 task_group: TaskGroup | None = None, 

310 params: collections.abc.MutableMapping | None = None, 

311 **kwargs, 

312 ): 

313 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext 

314 

315 validate_mapping_kwargs(operator_class, "partial", kwargs) 

316 

317 dag = dag or DagContext.get_current() 

318 if dag: 

319 task_group = task_group or TaskGroupContext.get_current(dag) 

320 if task_group: 

321 task_id = task_group.child_id(task_id) 

322 

323 # Merge Dag and task group level defaults into user-supplied values. 

324 dag_default_args, partial_params = get_merged_defaults( 

325 dag=dag, 

326 task_group=task_group, 

327 task_params=params, 

328 task_default_args=kwargs.pop("default_args", None), 

329 ) 

330 

331 # Create partial_kwargs from args and kwargs 

332 partial_kwargs: dict[str, Any] = { 

333 "task_id": task_id, 

334 "dag": dag, 

335 "task_group": task_group, 

336 **kwargs, 

337 } 

338 

339 # Inject Dag-level default args into args provided to this function. 

340 # Most of the default args will be retrieved during unmapping; here we 

341 # only ensure base properties are correctly set for the scheduler. 

342 partial_kwargs.update( 

343 (k, v) 

344 for k, v in dag_default_args.items() 

345 if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names 

346 ) 

347 

348 # Fill fields not provided by the user with default values. 

349 partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs) 

350 

351 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). 

352 if "task_concurrency" in kwargs: # Reject deprecated option. 

353 raise TypeError("unexpected argument: task_concurrency") 

354 if start_date := partial_kwargs.get("start_date", None): 

355 partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) 

356 if end_date := partial_kwargs.get("end_date", None): 

357 partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) 

358 if partial_kwargs["pool_slots"] < 1: 

359 dag_str = "" 

360 if dag: 

361 dag_str = f" in dag {dag.dag_id}" 

362 raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") 

363 if retries := partial_kwargs.get("retries"): 

364 partial_kwargs["retries"] = BaseOperator._convert_retries(retries) 

365 partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"]) 

366 partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay( 

367 partial_kwargs.get("max_retry_delay", None) 

368 ) 

369 

370 for k in ("execute", "failure", "success", "retry", "skipped"): 

371 attr = f"on_{k}_callback"; partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr)) 

372 

373 return OperatorPartial( 

374 operator_class=operator_class, 

375 kwargs=partial_kwargs, 

376 params=partial_params, 

377 ) 
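
# How this is usually reached (illustrative, assumed operator; must run inside a Dag
# context): ``partial`` backs dynamic task mapping via ``Operator.partial(...).expand(...)``.
#   BashOperator.partial(task_id="greet", retries=2).expand(
#       bash_command=["echo a", "echo b"],
#   )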

378 

379 

380class ExecutorSafeguard: 

381 """ 

382 The ExecutorSafeguard decorator. 

383 

384 Checks if the execute method of an operator isn't manually called outside 

385 the TaskInstance as we want to avoid bad mixing between decorated and 

386 classic operators. 

387 """ 

388 

389 test_mode: ClassVar[bool] = False 

390 tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel") 

391 sentinel_value: ClassVar[object] = object() 

392 

393 @classmethod 

394 def decorator(cls, func): 

395 @wraps(func) 

396 def wrapper(self, *args, **kwargs): 

397 sentinel_key = f"{self.__class__.__name__}__sentinel" 

398 sentinel = kwargs.pop(sentinel_key, None) 

399 

400 with contextlib.ExitStack() as stack: 

401 if sentinel is cls.sentinel_value: 

402 token = cls.tracker.set(self) 

403 sentinel = self 

404 stack.callback(cls.tracker.reset, token) 

405 else: 

406 # No sentinel passed in, maybe the subclass execute did have it passed? 

407 sentinel = cls.tracker.get(None) 

408 

409 if not cls.test_mode and sentinel is not self: 

410 message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!" 

411 if not self.allow_nested_operators: 

412 raise RuntimeError(message) 

413 self.log.warning(message) 

414 

415 # Now that we've logged, set sentinel so that `super()` calls don't log again 

416 token = cls.tracker.set(self) 

417 stack.callback(cls.tracker.reset, token) 

418 

419 return func(self, *args, **kwargs) 

420 

421 return wrapper 
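
# Sketch of the guarded call path (illustrative; the sentinel kwarg name is derived from
# the operator class as shown above): the Task Runner is expected to pass the sentinel,
#   op.execute(context, **{f"{type(op).__name__}__sentinel": ExecutorSafeguard.sentinel_value})
# so the guard passes, while a bare ``op.execute(context)`` issued from inside another
# task warns or raises depending on ``allow_nested_operators``.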

422 

423 

424if "airflow.configuration" in sys.modules: 

425 # Don't try and import it if it's not already loaded 

426 from airflow.sdk.configuration import conf 

427 

428 ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode") 

429 

430 

431def _collect_from_input(value_or_values: None | C | Collection[C]) -> list[C]: 

432 if not value_or_values: 

433 return [] 

434 if isinstance(value_or_values, Collection): 

435 return list(value_or_values) 

436 return [value_or_values] 
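
# Quick behaviour sketch (not part of the module):
#   _collect_from_input(None)         -> []
#   _collect_from_input(my_callback)  -> [my_callback]
#   _collect_from_input([cb1, cb2])   -> [cb1, cb2]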

437 

438 

439class BaseOperatorMeta(abc.ABCMeta): 

440 """Metaclass of BaseOperator.""" 

441 

442 @classmethod 

443 def _apply_defaults(cls, func: T) -> T: 

444 """ 

445 Look for an argument named "default_args", and fill the unspecified arguments from it. 

446 

447 Since python2.* isn't clear about which arguments are missing when 

448 calling a function, and this can be quite confusing with multi-level 

449 inheritance and argument defaults, this decorator also alerts with 

450 specific information about the missing arguments. 

451 """ 

452 # Cache inspect.signature for the wrapper closure to avoid calling it 

453 # at every decorated invocation. This is separate sig_cache created 

454 # per decoration, i.e. each function decorated using apply_defaults will 

455 # have a different sig_cache. 

456 sig_cache = inspect.signature(func) 

457 non_variadic_params = { 

458 name: param 

459 for (name, param) in sig_cache.parameters.items() 

460 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) 

461 } 

462 non_optional_args = { 

463 name 

464 for name, param in non_variadic_params.items() 

465 if param.default == param.empty and name != "task_id" 

466 } 

467 

468 fixup_decorator_warning_stack(func) 

469 

470 @wraps(func) 

471 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: 

472 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext 

473 

474 if args: 

475 raise TypeError("Use keyword arguments when initializing operators") 

476 

477 instantiated_from_mapped = kwargs.pop( 

478 "_airflow_from_mapped", 

479 getattr(self, "_BaseOperator__from_mapped", False), 

480 ) 

481 

482 dag: DAG | None = kwargs.get("dag") 

483 if dag is None: 

484 dag = DagContext.get_current() 

485 if dag is not None: 

486 kwargs["dag"] = dag 

487 

488 task_group: TaskGroup | None = kwargs.get("task_group") 

489 if dag and not task_group: 

490 task_group = TaskGroupContext.get_current(dag) 

491 if task_group is not None: 

492 kwargs["task_group"] = task_group 

493 

494 default_args, merged_params = get_merged_defaults( 

495 dag=dag, 

496 task_group=task_group, 

497 task_params=kwargs.pop("params", None), 

498 task_default_args=kwargs.pop("default_args", None), 

499 ) 

500 

501 for arg in sig_cache.parameters: 

502 if arg not in kwargs and arg in default_args: 

503 kwargs[arg] = default_args[arg] 

504 

505 missing_args = non_optional_args.difference(kwargs) 

506 if len(missing_args) == 1: 

507 raise TypeError(f"missing keyword argument {missing_args.pop()!r}") 

508 if missing_args: 

509 display = ", ".join(repr(a) for a in sorted(missing_args)) 

510 raise TypeError(f"missing keyword arguments {display}") 

511 

512 if merged_params: 

513 kwargs["params"] = merged_params 

514 

515 hook = getattr(self, "_hook_apply_defaults", None) 

516 if hook: 

517 args, kwargs = hook(**kwargs, default_args=default_args) 

518 default_args = kwargs.pop("default_args", {}) 

519 

520 if not hasattr(self, "_BaseOperator__init_kwargs"): 

521 object.__setattr__(self, "_BaseOperator__init_kwargs", {}) 

522 object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped) 

523 

524 result = func(self, **kwargs, default_args=default_args) 

525 

526 # Store the args passed to init -- we need them to support task.map serialization! 

527 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore 

528 

529 # Set upstream task defined by XComArgs passed to template fields of the operator. 

530 # BUT: only do this _ONCE_, not once for each class in the hierarchy 

531 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc] 

532 self._set_xcomargs_dependencies() 

533 # Mark instance as instantiated so that future attr setting updates xcomarg-based deps. 

534 object.__setattr__(self, "_BaseOperator__instantiated", True) 

535 

536 return result 

537 

538 apply_defaults.__non_optional_args = non_optional_args # type: ignore 

539 apply_defaults.__param_names = set(non_variadic_params) # type: ignore 

540 

541 return cast("T", apply_defaults) 

542 

543 def __new__(cls, name, bases, namespace, **kwargs): 

544 execute_method = namespace.get("execute") 

545 if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False): 

546 namespace["execute"] = ExecutorSafeguard.decorator(execute_method) 

547 new_cls = super().__new__(cls, name, bases, namespace, **kwargs) 

548 with contextlib.suppress(KeyError): 

549 # Update the partial descriptor with the class method, so it calls the actual function 

550 # (but let subclasses override it if they need to) 

551 partial_desc = vars(new_cls)["partial"] 

552 if isinstance(partial_desc, _PartialDescriptor): 

553 partial_desc.class_method = classmethod(partial) 

554 

555 # We patch `__init__` only if the class defines it. 

556 first_superclass = new_cls.mro()[1] 

557 if new_cls.__init__ is not first_superclass.__init__: 

558 new_cls.__init__ = cls._apply_defaults(new_cls.__init__) 

559 

560 return new_cls 
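
# Net effect of the metaclass (illustrative sketch, assumed operator): ``execute`` gets
# wrapped by ExecutorSafeguard and ``__init__`` by _apply_defaults, so dag-level
# default_args fill in keyword arguments the task did not set explicitly.
#   with DAG("demo", default_args={"retries": 3}):
#       t = BashOperator(task_id="t", bash_command="true")
#   assert t.retries == 3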

561 

562 

563# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the 

564# correct type. This is a temporary solution until we find a more sophisticated method for argument 

565# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not 

566# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python 

567# version that supports `get_type_hints` effectively or find a better approach, we can replace this 

568# manual type-checking method. 

569BASEOPERATOR_ARGS_EXPECTED_TYPES = { 

570 "task_id": str, 

571 "email": (str, Sequence), 

572 "email_on_retry": bool, 

573 "email_on_failure": bool, 

574 "retries": int, 

575 "retry_exponential_backoff": (int, float), 

576 "depends_on_past": bool, 

577 "ignore_first_depends_on_past": bool, 

578 "wait_for_past_depends_before_skipping": bool, 

579 "wait_for_downstream": bool, 

580 "priority_weight": int, 

581 "queue": str, 

582 "pool": str, 

583 "pool_slots": int, 

584 "trigger_rule": str, 

585 "run_as_user": str, 

586 "task_concurrency": int, 

587 "map_index_template": str, 

588 "max_active_tis_per_dag": int, 

589 "max_active_tis_per_dagrun": int, 

590 "executor": str, 

591 "do_xcom_push": bool, 

592 "multiple_outputs": bool, 

593 "doc": str, 

594 "doc_md": str, 

595 "doc_json": str, 

596 "doc_yaml": str, 

597 "doc_rst": str, 

598 "task_display_name": str, 

599 "logger_name": str, 

600 "allow_nested_operators": bool, 

601 "start_date": datetime, 

602 "end_date": datetime, 

603} 
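
# Illustrative note (not part of the module): ``validate_instance_args``, called at the
# end of ``BaseOperator.__init__``, checks attributes against this mapping, so a
# mis-typed argument such as ``priority_weight="high"`` is flagged at construction time
# rather than surfacing later during scheduling.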

604 

605 

606# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in 

607# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it. 

608# 

609# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you 

610# get nowhere leave your record here for the next poor soul and what problems you ran in to. 

611# 

612# @ashb, 2024/10/14 

613# - "Can't combine custom __setattr__ with on_setattr hooks" 

614 # - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses 

615@total_ordering 

616@dataclass(repr=False) 

617class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): 

618 r""" 

619 Abstract base class for all operators. 

620 

621 Since operators create objects that become nodes in the Dag, BaseOperator 

622 contains many recursive methods for Dag crawling behavior. To derive from 

623 this class, you are expected to override the constructor and the 'execute' 

624 method. 

625 

626 Operators derived from this class should perform or trigger certain tasks 

627 synchronously (wait for completion). Example of operators could be an 

628 operator that runs a Pig job (PigOperator), a sensor operator that 

629 waits for a partition to land in Hive (HiveSensorOperator), or one that 

630 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these 

631 operators (tasks) target specific operations, running specific scripts, 

632 functions or data transfers. 

633 

634 This class is abstract and shouldn't be instantiated. Instantiating a 

635 class derived from this one results in the creation of a task object, 

636 which ultimately becomes a node in Dag objects. Task dependencies should 

637 be set by using the set_upstream and/or set_downstream methods. 

638 

639 :param task_id: a unique, meaningful id for the task 

640 :param owner: the owner of the task. Using a meaningful description 

641 (e.g. user/person/team/role name) to clarify ownership is recommended. 

642 :param email: the 'to' email address(es) used in email alerts. This can be a 

643 single email or multiple ones. Multiple addresses can be specified as a 

644 comma or semicolon separated string or by passing a list of strings. (deprecated) 

645 :param email_on_retry: Indicates whether email alerts should be sent when a 

646 task is retried (deprecated) 

647 :param email_on_failure: Indicates whether email alerts should be sent when 

648 a task fails (deprecated) 

649 :param retries: the number of retries that should be performed before 

650 failing the task 

651 :param retry_delay: delay between retries, can be set as ``timedelta`` or 

652 ``float`` seconds, which will be converted into ``timedelta``, 

653 the default is ``timedelta(seconds=300)``. 

654 :param retry_exponential_backoff: multiplier for exponential backoff between retries. 

655 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff 

656 (delay doubles with each retry). For example, with retry_delay=4min and 

657 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc. 

658 :param max_retry_delay: maximum delay interval between retries, can be set as 

659 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``. 

660 :param start_date: The ``start_date`` for the task, determines 

661 the ``logical_date`` for the first task instance. The best practice 

662 is to have the start_date rounded 

663 to your Dag's ``schedule_interval``. Daily jobs have their start_date 

664 some day at 00:00:00, hourly jobs have their start_date at 00:00 

665 of a specific hour. Note that Airflow simply looks at the latest 

666 ``logical_date`` and adds the ``schedule_interval`` to determine 

667 the next ``logical_date``. It is also very important 

668 to note that different tasks' dependencies 

669 need to line up in time. If task A depends on task B and their 

670 start_date are offset in a way that their logical_date don't line 

671 up, A's dependencies will never be met. If you are looking to delay 

672 a task, for example running a daily task at 2AM, look into the 

673 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using 

674 dynamic ``start_date`` and recommend using fixed ones. Read the 

675 FAQ entry about start_date for more information. 

676 :param end_date: if specified, the scheduler won't go beyond this date 

677 :param depends_on_past: when set to true, task instances will run 

678 sequentially and only if the previous instance has succeeded or has been skipped. 

679 The task instance for the start_date is allowed to run. 

680 :param wait_for_past_depends_before_skipping: when set to true, if the task instance 

681 should be marked as skipped, and depends_on_past is true, the ti will stay in None state 

682 waiting for the task of the previous run to complete 

683 :param wait_for_downstream: when set to true, an instance of task 

684 X will wait for tasks immediately downstream of the previous instance 

685 of task X to finish successfully or be skipped before it runs. This is useful if the 

686 different instances of a task X alter the same asset, and this asset 

687 is used by tasks downstream of task X. Note that depends_on_past 

688 is forced to True wherever wait_for_downstream is used. Also note that 

689 only tasks *immediately* downstream of the previous task instance are waited 

690 for; the statuses of any tasks further downstream are ignored. 

691 :param dag: a reference to the dag the task is attached to (if any) 

692 :param priority_weight: priority weight of this task against other tasks. 

693 This allows the executor to trigger higher priority tasks before 

694 others when things get backed up. Set priority_weight as a higher 

695 number for more important tasks. 

696 As not all database engines support 64-bit integers, values are capped to the 32-bit range. 

697 Valid range is from -2,147,483,648 to 2,147,483,647. 

698 :param weight_rule: weighting method used for the effective total 

699 priority weight of the task. Options are: 

700 ``{ downstream | upstream | absolute }`` default is ``downstream`` 

701 When set to ``downstream`` the effective weight of the task is the 

702 aggregate sum of all downstream descendants. As a result, upstream 

703 tasks will have higher weight and will be scheduled more aggressively 

704 when using positive weight values. This is useful when you have 

705 multiple dag run instances and desire to have all upstream tasks to 

706 complete for all runs before each dag can continue processing 

707 downstream tasks. When set to ``upstream`` the effective weight is the 

708 aggregate sum of all upstream ancestors. This is the opposite where 

709 downstream tasks have higher weight and will be scheduled more 

710 aggressively when using positive weight values. This is useful when you 

711 have multiple dag run instances and prefer to have each dag complete 

712 before starting upstream tasks of other dags. When set to 

713 ``absolute``, the effective weight is the exact ``priority_weight`` 

714 specified without additional weighting. You may want to do this when 

715 you know exactly what priority weight each task should have. 

716 Additionally, when set to ``absolute``, there is a bonus effect of 

717 significantly speeding up the task creation process for very large 

718 Dags. Options can be set as a string or using the constants defined in 

719 the static class ``airflow.utils.WeightRule``. 

720 Irrespective of the weight rule, resulting priority values are capped to the 32-bit range. 

721 |experimental| 

722 Since 2.9.0, Airflow allows defining a custom priority weight strategy 

723 by creating a subclass of 

724 ``airflow.task.priority_strategy.PriorityWeightStrategy``, registering it 

725 in a plugin, and then providing the class path or the class instance via 

726 the ``weight_rule`` parameter. The custom priority weight strategy will be 

727 used to calculate the effective total priority weight of the task instance. 

728 :param queue: which queue to target when running this job. Not 

729 all executors implement queue management; the CeleryExecutor, for example, 

730 does support targeting specific queues. 

731 :param pool: the slot pool this task should run in, slot pools are a 

732 way to limit concurrency for certain tasks 

733 :param pool_slots: the number of pool slots this task should use (>= 1) 

734 Values less than 1 are not allowed. 

735 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a 

736 new implementation in Airflow >=3.1. 

737 :param execution_timeout: max time allowed for the execution of 

738 this task instance, if it goes beyond it will raise and fail. 

739 :param on_failure_callback: a function or list of functions to be called when a task instance 

740 of this task fails. A context dictionary is passed as a single 

741 parameter to this function. Context contains references to objects 

742 related to the task instance and is documented under the macros 

743 section of the API. 

744 :param on_execute_callback: much like the ``on_failure_callback`` except 

745 that it is executed right before the task is executed. 

746 :param on_retry_callback: much like the ``on_failure_callback`` except 

747 that it is executed when retries occur. 

748 :param on_success_callback: much like the ``on_failure_callback`` except 

749 that it is executed when the task succeeds. 

750 :param on_skipped_callback: much like the ``on_failure_callback`` except 

751 that it is executed when the task is skipped; this callback is called only if AirflowSkipException is raised. 

752 Explicitly, it is NOT called if the task never starts executing because of a preceding branching 

753 decision in the Dag or a trigger rule that causes execution to be skipped, so that the task execution 

754 is never scheduled. 

755 :param pre_execute: a function to be called immediately before task 

756 execution, receiving a context dictionary; raising an exception will 

757 prevent the task from being executed. 

758 :param post_execute: a function to be called immediately after task 

759 execution, receiving a context dictionary and task result; raising an 

760 exception will prevent the task from succeeding. 

761 :param trigger_rule: defines the rule by which dependencies are applied 

762 for the task to get triggered. Options are: 

763 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done | 

764 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}`` 

765 default is ``all_success``. Options can be set as string or 

766 using the constants defined in the static class 

767 ``airflow.utils.TriggerRule`` 

768 :param resources: A map of resource parameter names (the argument names of the 

769 Resources constructor) to their values. 

770 :param run_as_user: unix username to impersonate while running the task 

771 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent 

772 runs across logical_dates. 

773 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent 

774 task instances per Dag run. 

775 :param executor: Which executor to target when running this task. NOT YET SUPPORTED 

776 :param executor_config: Additional task-level configuration parameters that are 

777 interpreted by a specific executor. Parameters are namespaced by the name of 

778 executor. 

779 

780 **Example**: to run this task in a specific docker container through 

781 the KubernetesExecutor :: 

782 

783 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}) 

784 

785 :param do_xcom_push: if True, an XCom is pushed containing the Operator's 

786 result 

787 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each 

788 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom. 

789 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not 

790 using a TaskGroup as a context manager. 

791 :param doc: Add documentation or notes to your Task objects that are visible in 

792 the Task Instance details view in the Webserver 

793 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects 

794 that are visible in the Task Instance details view in the Webserver 

795 :param doc_rst: Add documentation (in RST format) or notes to your Task objects 

796 that are visible in the Task Instance details view in the Webserver 

797 :param doc_json: Add documentation (in JSON format) or notes to your Task objects 

798 that are visible in the Task Instance details view in the Webserver 

799 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects 

800 that are visible in the Task Instance details view in the Webserver 

801 :param task_display_name: The display name of the task which appears on the UI. 

802 :param logger_name: Name of the logger used by the Operator to emit logs. 

803 If set to `None` (default), the logger name will fall back to 

804 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have 

805 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger). 

806 :param allow_nested_operators: if True, when an operator is executed within another one a warning message 

807 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested 

808 within another one). In future releases of Airflow this parameter will be removed and an exception 

809 will always be thrown when operators are nested within each other (default is True). 

810 

811 **Example**: a bad operator mixin usage:: 

812 

813 @task(provide_context=True) 

814 def say_hello_world(**context): 

815 hello_world_task = BashOperator( 

816 task_id="hello_world_task", 

817 bash_command="python -c \"print('Hello, world!')\"", 

818 dag=dag, 

819 ) 

820 hello_world_task.execute(context) 

821 """ 

822 

823 task_id: str 

824 owner: str = DEFAULT_OWNER 

825 email: str | Sequence[str] | None = None 

826 email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY 

827 email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE 

828 retries: int | None = DEFAULT_RETRIES 

829 retry_delay: timedelta = DEFAULT_RETRY_DELAY 

830 retry_exponential_backoff: float = 0 

831 max_retry_delay: timedelta | float | None = None 

832 start_date: datetime | None = None 

833 end_date: datetime | None = None 

834 depends_on_past: bool = False 

835 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST 

836 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING 

837 wait_for_downstream: bool = False 

838 

839 # At execution time this becomes a normal dict 

840 params: ParamsDict | dict = field(default_factory=ParamsDict) 

841 default_args: dict | None = None 

842 priority_weight: int = DEFAULT_PRIORITY_WEIGHT 

843 weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE) 

844 queue: str = DEFAULT_QUEUE 

845 pool: str = DEFAULT_POOL_NAME 

846 pool_slots: int = DEFAULT_POOL_SLOTS 

847 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT 

848 on_execute_callback: Sequence[TaskStateChangeCallback] = () 

849 on_failure_callback: Sequence[TaskStateChangeCallback] = () 

850 on_success_callback: Sequence[TaskStateChangeCallback] = () 

851 on_retry_callback: Sequence[TaskStateChangeCallback] = () 

852 on_skipped_callback: Sequence[TaskStateChangeCallback] = () 

853 _pre_execute_hook: TaskPreExecuteHook | None = None 

854 _post_execute_hook: TaskPostExecuteHook | None = None 

855 trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE 

856 resources: dict[str, Any] | None = None 

857 run_as_user: str | None = None 

858 task_concurrency: int | None = None 

859 map_index_template: str | None = None 

860 max_active_tis_per_dag: int | None = None 

861 max_active_tis_per_dagrun: int | None = None 

862 executor: str | None = None 

863 executor_config: dict | None = None 

864 do_xcom_push: bool = True 

865 multiple_outputs: bool = False 

866 inlets: list[Any] = field(default_factory=list) 

867 outlets: list[Any] = field(default_factory=list) 

868 task_group: TaskGroup | None = None 

869 doc: str | None = None 

870 doc_md: str | None = None 

871 doc_json: str | None = None 

872 doc_yaml: str | None = None 

873 doc_rst: str | None = None 

874 _task_display_name: str | None = None 

875 logger_name: str | None = None 

876 allow_nested_operators: bool = True 

877 

878 is_setup: bool = False 

879 is_teardown: bool = False 

880 

881 # TODO: Task-SDK: Make these ClassVar[]? 

882 template_fields: Collection[str] = () 

883 template_ext: Sequence[str] = () 

884 

885 template_fields_renderers: ClassVar[dict[str, str]] = {} 

886 

887 operator_extra_links: Collection[BaseOperatorLink] = () 

888 

889 # Defines the color in the UI 

890 ui_color: str = "#fff" 

891 ui_fgcolor: str = "#000" 

892 

893 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore 

894 

895 _dag: DAG | None = field(init=False, default=None) 

896 

897 # Make this optional so the type matches the one defined in LoggingMixin 

898 _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False) 

899 _logger_name: str | None = None 

900 

901 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called 

902 __serialized_fields: ClassVar[frozenset[str] | None] = None 

903 

904 _comps: ClassVar[set[str]] = { 

905 "task_id", 

906 "dag_id", 

907 "owner", 

908 "email", 

909 "email_on_retry", 

910 "retry_delay", 

911 "retry_exponential_backoff", 

912 "max_retry_delay", 

913 "start_date", 

914 "end_date", 

915 "depends_on_past", 

916 "wait_for_downstream", 

917 "priority_weight", 

918 "execution_timeout", 

919 "has_on_execute_callback", 

920 "has_on_failure_callback", 

921 "has_on_success_callback", 

922 "has_on_retry_callback", 

923 "has_on_skipped_callback", 

924 "do_xcom_push", 

925 "multiple_outputs", 

926 "allow_nested_operators", 

927 "executor", 

928 } 

929 

930 # If True, the Rendered Template fields will be overwritten in DB after execution 

931 # This is useful for Taskflow decorators that modify the template fields during execution like 

932 # @task.bash decorator. 

933 overwrite_rtif_after_execution: bool = False 

934 

935 # If True then the class constructor was called 

936 __instantiated: bool = False 

937 # List of args as passed to `__init__()`, after apply_defaults() has updated them. Used to "recreate" the task 

938 # when mapping 

939 # Set via the metaclass 

940 __init_kwargs: dict[str, Any] = field(init=False) 

941 

942 # Set to True before calling execute method 

943 _lock_for_execution: bool = False 

944 

945 # Set to True for an operator instantiated by a mapped operator. 

946 __from_mapped: bool = False 

947 

948 start_trigger_args: StartTriggerArgs | None = None 

949 start_from_trigger: bool = False 

950 

951 # base list which includes all the attrs that don't need deep copy. 

952 _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = ( 

953 "user_defined_macros", 

954 "user_defined_filters", 

955 "params", 

956 ) 

957 

958 # each operator should override this class attr for shallow copy attrs. 

959 shallow_copy_attrs: Sequence[str] = () 

960 

961 def __setattr__(self: BaseOperator, key: str, value: Any): 

962 if converter := getattr(self, f"_convert_{key}", None): 

963 value = converter(value) 

964 super().__setattr__(key, value) 

965 if self.__from_mapped or self._lock_for_execution: 

966 return # Skip any custom behavior for validation and during execute. 

967 if key in self.__init_kwargs: 

968 self.__init_kwargs[key] = value 

969 if self.__instantiated and key in self.template_fields: 

970 # Resolve upstreams set by assigning an XComArg after initializing 

971 # an operator, example: 

972 # op = BashOperator() 

973 # op.bash_command = "sleep 1" 

974 self._set_xcomargs_dependency(key, value) 
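
# Sketch of the ``_convert_<name>`` hook behaviour above (illustrative, not part of the
# module; ``bash_command`` assumes a BashOperator-style template field):
#   op.retry_delay = 30        # stored as timedelta(seconds=30) via _convert_retry_delay
#   op.retries = "5"           # stored as 5 via _convert_retries
#   op.bash_command = other_task.output  # template field: also wires an upstream XComArg dep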

975 

976 def __init__( 

977 self, 

978 *, 

979 task_id: str, 

980 owner: str = DEFAULT_OWNER, 

981 email: str | Sequence[str] | None = None, 

982 email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY, 

983 email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE, 

984 retries: int | None = DEFAULT_RETRIES, 

985 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, 

986 retry_exponential_backoff: float = 0, 

987 max_retry_delay: timedelta | float | None = None, 

988 start_date: datetime | None = None, 

989 end_date: datetime | None = None, 

990 depends_on_past: bool = False, 

991 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

992 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

993 wait_for_downstream: bool = False, 

994 dag: DAG | None = None, 

995 params: collections.abc.MutableMapping[str, Any] | None = None, 

996 default_args: dict | None = None, 

997 priority_weight: int = DEFAULT_PRIORITY_WEIGHT, 

998 weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE, 

999 queue: str = DEFAULT_QUEUE, 

1000 pool: str | None = None, 

1001 pool_slots: int = DEFAULT_POOL_SLOTS, 

1002 sla: timedelta | None = None, 

1003 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, 

1004 on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1005 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, 

1006 on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1007 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, 

1008 on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1009 pre_execute: TaskPreExecuteHook | None = None, 

1010 post_execute: TaskPostExecuteHook | None = None, 

1011 trigger_rule: str = DEFAULT_TRIGGER_RULE, 

1012 resources: dict[str, Any] | None = None, 

1013 run_as_user: str | None = None, 

1014 map_index_template: str | None = None, 

1015 max_active_tis_per_dag: int | None = None, 

1016 max_active_tis_per_dagrun: int | None = None, 

1017 executor: str | None = None, 

1018 executor_config: dict | None = None, 

1019 do_xcom_push: bool = True, 

1020 multiple_outputs: bool = False, 

1021 inlets: Any | None = None, 

1022 outlets: Any | None = None, 

1023 task_group: TaskGroup | None = None, 

1024 doc: str | None = None, 

1025 doc_md: str | None = None, 

1026 doc_json: str | None = None, 

1027 doc_yaml: str | None = None, 

1028 doc_rst: str | None = None, 

1029 task_display_name: str | None = None, 

1030 logger_name: str | None = None, 

1031 allow_nested_operators: bool = True, 

1032 **kwargs: Any, 

1033 ): 

1034 # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any 

1035 

1036 # Only apply task_group prefix if this operator was not created from a mapped operator 

1037 # Mapped operators already have the prefix applied during their creation 

1038 if task_group and not self.__from_mapped: 

1039 self.task_id = task_group.child_id(task_id) 

1040 task_group.add(self) 

1041 else: 

1042 self.task_id = task_id 

1043 

1044 super().__init__() 

1045 self.task_group = task_group 

1046 

1047 kwargs.pop("_airflow_mapped_validation_only", None) 

1048 if kwargs: 

1049 raise TypeError( 

1050 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). " 

1051 f"Invalid arguments were:\n**kwargs: {redact(kwargs)}", 

1052 ) 

1053 validate_key(self.task_id) 

1054 

1055 self.owner = owner 

1056 self.email = email 

1057 self.email_on_retry = email_on_retry 

1058 self.email_on_failure = email_on_failure 

1059 

1060 if email is not None: 

1061 warnings.warn( 

1062 "Setting email on a task is deprecated; please migrate to SmtpNotifier.", 

1063 RemovedInAirflow4Warning, 

1064 stacklevel=2, 

1065 ) 

1066 if email and email_on_retry is not None: 

1067 warnings.warn( 

1068 "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.", 

1069 RemovedInAirflow4Warning, 

1070 stacklevel=2, 

1071 ) 

1072 if email and email_on_failure is not None: 

1073 warnings.warn( 

1074 "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.", 

1075 RemovedInAirflow4Warning, 

1076 stacklevel=2, 

1077 ) 

1078 

1079 if execution_timeout is not None and not isinstance(execution_timeout, timedelta): 

1080 raise ValueError( 

1081 f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}" 

1082 ) 

1083 self.execution_timeout = execution_timeout 

1084 

1085 self.on_execute_callback = _collect_from_input(on_execute_callback) 

1086 self.on_failure_callback = _collect_from_input(on_failure_callback) 

1087 self.on_success_callback = _collect_from_input(on_success_callback) 

1088 self.on_retry_callback = _collect_from_input(on_retry_callback) 

1089 self.on_skipped_callback = _collect_from_input(on_skipped_callback) 

1090 self._pre_execute_hook = pre_execute 

1091 self._post_execute_hook = post_execute 

1092 

1093 self.start_date = timezone.convert_to_utc(start_date) 

1094 self.end_date = timezone.convert_to_utc(end_date) 

1095 self.executor = executor 

1096 self.executor_config = executor_config or {} 

1097 self.run_as_user = run_as_user 

1098 # TODO: 

1099 # self.retries = parse_retries(retries) 

1100 self.retries = retries 

1101 self.queue = queue 

1102 self.pool = DEFAULT_POOL_NAME if pool is None else pool 

1103 self.pool_slots = pool_slots 

1104 if self.pool_slots < 1: 

1105 dag_str = f" in dag {dag.dag_id}" if dag else "" 

1106 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1") 

1107 if sla is not None: 

1108 warnings.warn( 

1109 "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1", 

1110 stacklevel=2, 

1111 ) 

1112 

1113 try: 

1114 TriggerRule(trigger_rule) 

1115 except ValueError: 

1116 raise ValueError( 

1117 f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]}," 

1118 f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'." 

1119 ) 

1120 

1121 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule) 

1122 

1123 self.depends_on_past: bool = depends_on_past 

1124 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past 

1125 self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping 

1126 self.wait_for_downstream: bool = wait_for_downstream 

1127 if wait_for_downstream: 

1128 self.depends_on_past = True 

1129 

1130 # Converted by setattr 

1131 self.retry_delay = retry_delay # type: ignore[assignment] 

1132 self.retry_exponential_backoff = retry_exponential_backoff 

1133 if max_retry_delay is not None: 

1134 self.max_retry_delay = max_retry_delay 

1135 

1136 self.resources = resources 

1137 

1138 self.params = ParamsDict(params) 

1139 

1140 self.priority_weight = priority_weight 

1141 self.weight_rule = weight_rule 

1142 

1143 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag 

1144 self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun 

1145 self.do_xcom_push: bool = do_xcom_push 

1146 self.map_index_template: str | None = map_index_template 

1147 self.multiple_outputs: bool = multiple_outputs 

1148 

1149 self.doc_md = doc_md 

1150 self.doc_json = doc_json 

1151 self.doc_yaml = doc_yaml 

1152 self.doc_rst = doc_rst 

1153 self.doc = doc 

1154 

1155 self._task_display_name = task_display_name 

1156 

1157 self.allow_nested_operators = allow_nested_operators 

1158 

1159 self._logger_name = logger_name 

1160 

1161 # Lineage 

1162 self.inlets = _collect_from_input(inlets) 

1163 self.outlets = _collect_from_input(outlets) 

1164 

1165 if isinstance(self.template_fields, str): 

1166 warnings.warn( 

1167 f"The `template_fields` value for {self.task_type} is a string " 

1168 "but should be a list or tuple of string. Wrapping it in a list for execution. " 

1169 f"Please update {self.task_type} accordingly.", 

1170 UserWarning, 

1171 stacklevel=2, 

1172 ) 

1173 self.template_fields = [self.template_fields] 

1174 

1175 self.is_setup = False 

1176 self.is_teardown = False 

1177 

1178 if SetupTeardownContext.active: 

1179 SetupTeardownContext.update_context_map(self) 

1180 

1181 # We set self.dag right at the end as `_convert__dag` calls `dag.add_task` for us, and we need all the 

1182 # other properties to be set at that point 

1183 if dag is not None: 

1184 self.dag = dag 

1185 

1186 validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES) 

1187 

1188 # Ensure priority_weight is within the valid range 

1189 self.priority_weight = db_safe_priority(self.priority_weight) 

1190 

1191 def __eq__(self, other): 

1192 if type(self) is type(other): 

1193 # Use getattr() instead of __dict__ as __dict__ doesn't return 

1194 # correct values for properties. 

1195 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) 

1196 return False 

1197 

1198 def __ne__(self, other): 

1199 return not self == other 

1200 

1201 def __hash__(self): 

1202 hash_components = [type(self)] 

1203 for component in self._comps: 

1204 val = getattr(self, component, None) 

1205 try: 

1206 hash(val) 

1207 hash_components.append(val) 

1208 except TypeError: 

1209 hash_components.append(repr(val)) 

1210 return hash(tuple(hash_components)) 

1211 

1212 # /Composing Operators --------------------------------------------- 

1213 

1214 def __gt__(self, other): 

1215 """ 

1216 Return [Operator] > [Outlet]. 

1217 

1218 If other is an attr annotated object it is set as an outlet of this Operator. 

1219 """ 

1220 if not isinstance(other, Iterable): 

1221 other = [other] 

1222 

1223 for obj in other: 

1224 if not attrs.has(obj): 

1225 raise TypeError(f"Left hand side ({obj}) is not an outlet") 

1226 self.add_outlets(other) 

1227 

1228 return self 

1229 

1230 def __lt__(self, other): 

1231 """ 

1232 Return [Inlet] > [Operator] or [Operator] < [Inlet]. 

1233 

1234 If other is an attr annotated object it is set as an inlet to this operator. 

1235 """ 

1236 if not isinstance(other, Iterable): 

1237 other = [other] 

1238 

1239 for obj in other: 

1240 if not attrs.has(obj): 

1241 raise TypeError(f"{obj} cannot be an inlet") 

1242 self.add_inlets(other) 

1243 

1244 return self 
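
# Illustrative use of the two comparison operators above (assumes ``Asset``, an
# attrs-decorated SDK class; not part of the module):
#   task > Asset("s3://bucket/out.csv")   # registers the asset as an outlet of ``task``
#   task < Asset("s3://bucket/in.csv")    # registers the asset as an inlet of ``task``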

1245 

1246 def __deepcopy__(self, memo: dict[int, Any]): 

1247 # Hack sorting double chained task lists by task_id to avoid hitting 

1248 # max_depth on deepcopy operations. 

1249 sys.setrecursionlimit(5000) # TODO fix this in a better way 

1250 

1251 cls = self.__class__ 

1252 result = cls.__new__(cls) 

1253 memo[id(self)] = result 

1254 

1255 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs 

1256 

1257 for k, v_org in self.__dict__.items(): 

1258 if k not in shallow_copy: 

1259 v = copy.deepcopy(v_org, memo) 

1260 else: 

1261 v = copy.copy(v_org) 

1262 

1263 # Bypass any setters, and set it on the object directly. This works since we are cloning ourself so 

1264 # we know the type is already fine 

1265 result.__dict__[k] = v 

1266 return result 
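
# Sketch (assumed attribute name, not part of the module): subclasses opt attributes out
# of the deep copy by listing them in ``shallow_copy_attrs``.
#   class MyHttpOperator(BaseOperator):
#       shallow_copy_attrs = ("http_session",)  # copied by reference, not deep-copied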

1267 

1268 def __getstate__(self): 

1269 state = dict(self.__dict__) 

1270 if "_log" in state: 

1271 del state["_log"] 

1272 

1273 return state 

1274 

1275 def __setstate__(self, state): 

1276 self.__dict__ = state 

1277 

1278 def add_inlets(self, inlets: Iterable[Any]): 

1279 """Set inlets to this operator.""" 

1280 self.inlets.extend(inlets) 

1281 

1282 def add_outlets(self, outlets: Iterable[Any]): 

1283 """Define the outlets of this operator.""" 

1284 self.outlets.extend(outlets) 

1285 

1286 def get_dag(self) -> DAG | None: 

1287 return self._dag 

1288 

1289 @property 

1290 def dag(self) -> DAG: 

1291 """Returns the Operator's Dag if set, otherwise raises an error.""" 

1292 if dag := self._dag: 

1293 return dag 

1294 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet") 

1295 

1296 @dag.setter 

1297 def dag(self, dag: DAG | None) -> None: 

1298 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok.""" 

1299 self._dag = dag 

1300 

1301 def _convert__dag(self, dag: DAG | None) -> DAG | None: 

1302 # Called automatically by __setattr__ method 

1303 from airflow.sdk.definitions.dag import DAG 

1304 

1305 if dag is None: 

1306 return dag 

1307 

1308 if not isinstance(dag, DAG): 

1309 raise TypeError(f"Expected dag; received {dag.__class__.__name__}") 

1310 if self._dag is not None and self._dag is not dag: 

1311 raise ValueError(f"The dag assigned to {self} can not be changed.") 

1312 

1313 if self.__from_mapped: 

1314 pass # Don't add to dag -- the mapped task takes the place. 

1315 elif dag.task_dict.get(self.task_id) is not self: 

1316 dag.add_task(self) 

1317 return dag 

1318 
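# A short sketch of the single-assignment rule enforced by the dag setter and
# _convert__dag above (dag ids are illustrative):
from airflow.sdk import DAG

with DAG(dag_id="example_dag") as dag:
    t = BaseOperator(task_id="t")   # picked up by the active Dag via dag.add_task()

t.dag = dag                         # re-assigning the same Dag is accepted

try:
    t.dag = DAG(dag_id="another_dag")
except ValueError:
    pass                            # the Dag of an assigned operator cannot be changed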

1319 @staticmethod 

1320 def _convert_retries(retries: Any) -> int | None: 

1321 if retries is None: 

1322 return 0 

1323 if type(retries) == int: # noqa: E721 

1324 return retries 

1325 try: 

1326 parsed_retries = int(retries) 

1327 except (TypeError, ValueError): 

1328 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}") 

1329 return parsed_retries 

1330 

1331 @staticmethod 

1332 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None: 

1333 if value is None or isinstance(value, timedelta): 

1334 return value 

1335 return timedelta(seconds=value) 

1336 

1337 _convert_retry_delay = _convert_timedelta 

1338 _convert_max_retry_delay = _convert_timedelta 

1339 
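# Quick illustration of the converters above: retries are coerced to int and numeric
# delays to timedelta when the corresponding attributes are assigned.
assert BaseOperator._convert_retries(None) == 0
assert BaseOperator._convert_retries("3") == 3
assert BaseOperator._convert_timedelta(300) == timedelta(seconds=300)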

1340 @staticmethod 

1341 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None: 

1342 if resources is None: 

1343 return None 

1344 

1345 from airflow.sdk.definitions.operator_resources import Resources 

1346 

1347 if isinstance(resources, Resources): 

1348 return resources 

1349 

1350 return Resources(**resources) 

1351 

1352 def _convert_is_setup(self, value: bool) -> bool: 

1353 """ 

1354 Setter for is_setup property. 

1355 

1356 :meta private: 

1357 """ 

1358 if self.is_teardown and value: 

1359 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.") 

1360 return value 

1361 

1362 def _convert_is_teardown(self, value: bool) -> bool: 

1363 if self.is_setup and value: 

1364 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.") 

1365 return value 

1366 

1367 @property 

1368 def task_display_name(self) -> str: 

1369 return self._task_display_name or self.task_id 

1370 

1371 def has_dag(self): 

1372 """Return True if the Operator has been assigned to a Dag.""" 

1373 return self._dag is not None 

1374 

1375 def _set_xcomargs_dependencies(self) -> None: 

1376 from airflow.sdk.definitions.xcom_arg import XComArg 

1377 

1378 for f in self.template_fields: 

1379 arg = getattr(self, f, NOTSET) 

1380 if arg is not NOTSET: 

1381 XComArg.apply_upstream_relationship(self, arg) 

1382 

1383 def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None: 

1384 """ 

1385 Resolve upstream dependencies of a task. 

1386 

1387 In this way, passing an ``XComArg`` as the value of a template field 

1388 creates an upstream relation between the two tasks. 

1389 

1390 **Example**: :: 

1391 

1392 with DAG(...): 

1393 generate_content = GenerateContentOperator(task_id="generate_content") 

1394 send_email = EmailOperator(..., html_content=generate_content.output) 

1395 

1396 # This is equivalent to 

1397 with DAG(...): 

1398 generate_content = GenerateContentOperator(task_id="generate_content") 

1399 send_email = EmailOperator( 

1400 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}" 

1401 ) 

1402 generate_content >> send_email 

1403 

1404 """ 

1405 from airflow.sdk.definitions.xcom_arg import XComArg 

1406 

1407 if field not in self.template_fields: 

1408 return 

1409 XComArg.apply_upstream_relationship(self, newvalue) 

1410 

1411 def on_kill(self) -> None: 

1412 """ 

1413 Override this method to clean up subprocesses when a task instance gets killed. 

1414 

1415 Any use of the threading, subprocess or multiprocessing module within an 

1416 operator needs to be cleaned up, or it will leave ghost processes behind. 

1417 """ 

1418 
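# A hedged sketch of an operator that honours on_kill: the child process started in
# execute() is terminated when the task instance is killed (the command is illustrative).
import subprocess

class LongCommandOperator(BaseOperator):
    def execute(self, context):
        self._proc = subprocess.Popen(["sleep", "3600"])
        self._proc.wait()

    def on_kill(self):
        # Clean up the subprocess so no ghost process outlives the task instance.
        if getattr(self, "_proc", None) is not None:
            self._proc.terminate()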

1419 def __repr__(self): 

1420 return f"<Task({self.task_type}): {self.task_id}>" 

1421 

1422 @property 

1423 def operator_class(self) -> type[BaseOperator]: # type: ignore[override] 

1424 return self.__class__ 

1425 

1426 @property 

1427 def task_type(self) -> str: 

1428 """@property: type of the task.""" 

1429 return self.__class__.__name__ 

1430 

1431 @property 

1432 def operator_name(self) -> str: 

1433 """@property: use a more friendly display name for the operator, if set.""" 

1434 try: 

1435 return self.custom_operator_name # type: ignore 

1436 except AttributeError: 

1437 return self.task_type 

1438 

1439 @property 

1440 def roots(self) -> list[BaseOperator]: 

1441 """Required by DAGNode.""" 

1442 return [self] 

1443 

1444 @property 

1445 def leaves(self) -> list[BaseOperator]: 

1446 """Required by DAGNode.""" 

1447 return [self] 

1448 

1449 @property 

1450 def output(self) -> XComArg: 

1451 """Returns reference to XCom pushed by current operator.""" 

1452 from airflow.sdk.definitions.xcom_arg import XComArg 

1453 

1454 return XComArg(operator=self) 

1455 
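# Sketch of wiring tasks through the .output property (both operator classes are
# illustrative, and `data` is assumed to be in LoadOperator's template_fields):
# passing an XComArg as a template-field value both pulls the XCom at render time
# and creates the upstream relation (see _set_xcomargs_dependency above).
from airflow.sdk import DAG

with DAG(dag_id="etl"):
    extract = ExtractOperator(task_id="extract")
    load = LoadOperator(task_id="load", data=extract.output)  # implies extract >> load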

1456 @classmethod 

1457 def get_serialized_fields(cls): 

1458 """Stringified Dags and operators contain exactly these fields.""" 

1459 if not cls.__serialized_fields: 

1460 from airflow.sdk.definitions._internal.contextmanager import DagContext 

1461 

1462 # Make sure the following "fake" task is not added to the currently active 

1463 # dag in context; otherwise it will raise 

1464 # `RuntimeError: dictionary changed size during iteration` 

1465 # in the SerializedDAG.serialize_dag() call. 

1466 DagContext.push(None) 

1467 cls.__serialized_fields = frozenset( 

1468 vars(BaseOperator(task_id="test")).keys() 

1469 - { 

1470 "upstream_task_ids", 

1471 "default_args", 

1472 "dag", 

1473 "_dag", 

1474 "label", 

1475 "_BaseOperator__instantiated", 

1476 "_BaseOperator__init_kwargs", 

1477 "_BaseOperator__from_mapped", 

1478 "on_failure_fail_dagrun", 

1479 "task_group", 

1480 "_task_type", 

1481 "operator_extra_links", 

1482 "on_execute_callback", 

1483 "on_failure_callback", 

1484 "on_success_callback", 

1485 "on_retry_callback", 

1486 "on_skipped_callback", 

1487 } 

1488 | { # Class-level defaults and `@property` values need to be added to this list 

1489 "start_date", 

1490 "end_date", 

1491 "task_type", 

1492 "ui_color", 

1493 "ui_fgcolor", 

1494 "template_ext", 

1495 "template_fields", 

1496 "template_fields_renderers", 

1497 "params", 

1498 "is_setup", 

1499 "is_teardown", 

1500 "on_failure_fail_dagrun", 

1501 "map_index_template", 

1502 "start_trigger_args", 

1503 "_needs_expansion", 

1504 "start_from_trigger", 

1505 "max_retry_delay", 

1506 "has_on_execute_callback", 

1507 "has_on_failure_callback", 

1508 "has_on_success_callback", 

1509 "has_on_retry_callback", 

1510 "has_on_skipped_callback", 

1511 } 

1512 ) 

1513 DagContext.pop() 

1514 

1515 return cls.__serialized_fields 

1516 

1517 def prepare_for_execution(self) -> Self: 

1518 """Lock task for execution to disable custom action in ``__setattr__`` and return a copy.""" 

1519 other = copy.copy(self) 

1520 other._lock_for_execution = True 

1521 return other 

1522 

1523 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: 

1524 """Serialize; required by DAGNode.""" 

1525 from airflow.serialization.enums import DagAttributeTypes 

1526 

1527 return DagAttributeTypes.OP, self.task_id 

1528 

1529 def unmap(self, resolve: None | Mapping[str, Any]) -> Self: 

1530 """ 

1531 Get the "normal" operator from the current operator. 

1532 

1533 Since a BaseOperator is not mapped to begin with, this simply returns 

1534 the original operator. 

1535 

1536 :meta private: 

1537 """ 

1538 return self 

1539 

1540 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None: 

1541 """ 

1542 Get the start_trigger_args value of the current abstract operator. 

1543 

1544 Since a BaseOperator is not mapped to begin with, this simply returns 

1545 the original value of start_trigger_args. 

1546 

1547 :meta private: 

1548 """ 

1549 return self.start_trigger_args 

1550 

1551 def render_template_fields( 

1552 self, 

1553 context: Context, 

1554 jinja_env: jinja2.Environment | None = None, 

1555 ) -> None: 

1556 """ 

1557 Template all attributes listed in *self.template_fields*. 

1558 

1559 This mutates the attributes in-place and is irreversible. 

1560 

1561 :param context: Context dict with values to apply on content. 

1562 :param jinja_env: Jinja's environment to use for rendering. 

1563 """ 

1564 if not jinja_env: 

1565 jinja_env = self.get_template_env() 

1566 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) 

1567 
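# A minimal sketch of template-field rendering. At run time the Task Runner supplies
# the full Jinja context; here a single illustrative key ("ds") is passed by hand.
class EchoOperator(BaseOperator):
    template_fields = ("message",)

    def __init__(self, *, message: str, **kwargs):
        super().__init__(**kwargs)
        self.message = message

    def execute(self, context):
        self.log.info(self.message)

op = EchoOperator(task_id="echo", message="run date is {{ ds }}")
op.render_template_fields(context={"ds": "2025-01-01"})  # mutates op.message in place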

1568 def pre_execute(self, context: Any): 

1569 """Execute right before self.execute() is called.""" 

1570 

1571 def execute(self, context: Context) -> Any: 

1572 """ 

1573 Override this method when deriving an operator. 

1574 

1575 The main method to execute the task. Context is the same dictionary used 

1576 as when rendering jinja templates. 

1577 

1578 Refer to get_template_context for more context. 

1579 """ 

1580 raise NotImplementedError() 

1581 

1582 def post_execute(self, context: Any, result: Any = None): 

1583 """ 

1584 Execute right after self.execute() is called. 

1585 

1586 It is passed the execution context and any results returned by the operator. 

1587 """ 

1588 
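# A minimal custom operator showing the execute() contract described above; the class
# and its behaviour are illustrative.
class SumOperator(BaseOperator):
    def __init__(self, *, numbers, **kwargs):
        super().__init__(**kwargs)
        self.numbers = numbers

    def execute(self, context):
        total = sum(self.numbers)
        self.log.info("sum=%s", total)
        return total  # the returned value is pushed to XCom by the runtime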

1589 def defer( 

1590 self, 

1591 *, 

1592 trigger: BaseTrigger, 

1593 method_name: str, 

1594 kwargs: dict[str, Any] | None = None, 

1595 timeout: timedelta | int | float | None = None, 

1596 ) -> NoReturn: 

1597 """ 

1598 Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. 

1599 

1600 This is achieved by raising a special exception (TaskDeferred) 

1601 which is caught in the main _execute_task wrapper. Triggers can send execution back to the task or end 

1602 the task instance directly. If the trigger will end the task instance itself, ``method_name`` should 

1603 be None; otherwise, provide the name of the method that should be used when resuming execution in 

1604 the task. 

1605 """ 

1606 from airflow.sdk.exceptions import TaskDeferred 

1607 

1608 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) 

1609 

1610 def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): 

1611 """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed.""" 

1612 from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout 

1613 

1614 if next_kwargs is None: 

1615 next_kwargs = {} 

1616 # __fail__ is a special signal value for next_method that indicates 

1617 # this task was scheduled specifically to fail. 

1618 

1619 if next_method == TRIGGER_FAIL_REPR: 

1620 next_kwargs = next_kwargs or {} 

1621 traceback = next_kwargs.get("traceback") 

1622 if traceback is not None: 

1623 self.log.error("Trigger failed:\n%s", "\n".join(traceback)) 

1624 if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT: 

1625 raise TaskDeferralTimeout(error) 

1626 raise TaskDeferralError(error) 

1627 # Grab the callable off the Operator/Task and add in any kwargs 

1628 execute_callable = getattr(self, next_method) 

1629 return execute_callable(context, **next_kwargs) 

1630 
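# A hedged sketch of the deferral round trip: execute() suspends the task via
# self.defer(), and resume_execution() later invokes the method named in method_name
# with the trigger's event payload. MyEventTrigger is a hypothetical trigger class.
class WaitForEventOperator(BaseOperator):
    def execute(self, context):
        self.defer(
            trigger=MyEventTrigger(poll_interval=60),
            method_name="execute_complete",
            timeout=timedelta(hours=1),
        )

    def execute_complete(self, context, event=None):
        # Reached when the trigger fires; the event payload arrives via next_kwargs.
        return event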

1631 def dry_run(self) -> None: 

1632 """Perform dry run for the operator - just render template fields.""" 

1633 self.log.info("Dry run") 

1634 for f in self.template_fields: 

1635 try: 

1636 content = getattr(self, f) 

1637 except AttributeError: 

1638 raise AttributeError( 

1639 f"{f!r} is configured as a template field " 

1640 f"but {self.task_type} does not have this attribute." 

1641 ) 

1642 

1643 if content and isinstance(content, str): 

1644 self.log.info("Rendering template for %s", f) 

1645 self.log.info(content) 

1646 

1647 @property 

1648 def has_on_execute_callback(self) -> bool: 

1649 """Return True if the task has execute callbacks.""" 

1650 return bool(self.on_execute_callback) 

1651 

1652 @property 

1653 def has_on_failure_callback(self) -> bool: 

1654 """Return True if the task has failure callbacks.""" 

1655 return bool(self.on_failure_callback) 

1656 

1657 @property 

1658 def has_on_success_callback(self) -> bool: 

1659 """Return True if the task has success callbacks.""" 

1660 return bool(self.on_success_callback) 

1661 

1662 @property 

1663 def has_on_retry_callback(self) -> bool: 

1664 """Return True if the task has retry callbacks.""" 

1665 return bool(self.on_retry_callback) 

1666 

1667 @property 

1668 def has_on_skipped_callback(self) -> bool: 

1669 """Return True if the task has skipped callbacks.""" 

1670 return bool(self.on_skipped_callback) 

1671 

1672 

1673def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: 

1674 r""" 

1675 Given a number of tasks, builds a dependency chain. 

1676 

1677 This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups, 

1678 or lists containing any mix of these types. If you want to chain between two 

1679 lists, you must ensure they have the same length. 

1680 

1681 Using classic operators/sensors: 

1682 

1683 .. code-block:: python 

1684 

1685 chain(t1, [t2, t3], [t4, t5], t6) 

1686 

1687 is equivalent to:: 

1688 

1689              / -> t2 -> t4 \ 

1690            t1               -> t6 

1691              \ -> t3 -> t5 / 

1692 

1693 .. code-block:: python 

1694 

1695 t1.set_downstream(t2) 

1696 t1.set_downstream(t3) 

1697 t2.set_downstream(t4) 

1698 t3.set_downstream(t5) 

1699 t4.set_downstream(t6) 

1700 t5.set_downstream(t6) 

1701 

1702 Using task-decorated functions aka XComArgs: 

1703 

1704 .. code-block:: python 

1705 

1706 chain(x1(), [x2(), x3()], [x4(), x5()], x6()) 

1707 

1708 is equivalent to:: 

1709 

1710              / -> x2 -> x4 \ 

1711            x1               -> x6 

1712              \ -> x3 -> x5 / 

1713 

1714 .. code-block:: python 

1715 

1716 x1 = x1() 

1717 x2 = x2() 

1718 x3 = x3() 

1719 x4 = x4() 

1720 x5 = x5() 

1721 x6 = x6() 

1722 x1.set_downstream(x2) 

1723 x1.set_downstream(x3) 

1724 x2.set_downstream(x4) 

1725 x3.set_downstream(x5) 

1726 x4.set_downstream(x6) 

1727 x5.set_downstream(x6) 

1728 

1729 Using TaskGroups: 

1730 

1731 .. code-block:: python 

1732 

1733 chain(t1, task_group1, task_group2, t2) 

1734 

1735 t1.set_downstream(task_group1) 

1736 task_group1.set_downstream(task_group2) 

1737 task_group2.set_downstream(t2) 

1738 

1739 

1740 It is also possible to mix classic operators/sensors, EdgeModifiers, XComArgs, and TaskGroups: 

1741 

1742 .. code-block:: python 

1743 

1744 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3()) 

1745 

1746 is equivalent to:: 

1747 

1748 / "branch one" -> x1 \ 

1749 t1 -> task_group1 -> x3 

1750 \ "branch two" -> x2 / 

1751 

1752 .. code-block:: python 

1753 

1754 x1 = x1() 

1755 x2 = x2() 

1756 x3 = x3() 

1757 label1 = Label("branch one") 

1758 label2 = Label("branch two") 

1759 t1.set_downstream(label1) 

1760 label1.set_downstream(x1) 

1761 t2.set_downstream(label2) 

1762 label2.set_downstream(x2) 

1763 x1.set_downstream(task_group1) 

1764 x2.set_downstream(task_group1) 

1765 task_group1.set_downstream(x3) 

1766 

1767 # or 

1768 

1769 x1 = x1() 

1770 x2 = x2() 

1771 x3 = x3() 

1772 t1.set_downstream(x1, edge_modifier=Label("branch one")) 

1773 t1.set_downstream(x2, edge_modifier=Label("branch two")) 

1774 x1.set_downstream(task_group1) 

1775 x2.set_downstream(task_group1) 

1776 task_group1.set_downstream(x3) 

1777 

1778 

1779 :param tasks: Individual and/or lists of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies between 

1780 """ 

1781 for up_task, down_task in zip(tasks, tasks[1:]): 

1782 if isinstance(up_task, DependencyMixin): 

1783 up_task.set_downstream(down_task) 

1784 continue 

1785 if isinstance(down_task, DependencyMixin): 

1786 down_task.set_upstream(up_task) 

1787 continue 

1788 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence): 

1789 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}") 

1790 up_task_list = up_task 

1791 down_task_list = down_task 

1792 if len(up_task_list) != len(down_task_list): 

1793 raise ValueError( 

1794 f"Chain not supported for different length Iterable. " 

1795 f"Got {len(up_task_list)} and {len(down_task_list)}." 

1796 ) 

1797 for up_t, down_t in zip(up_task_list, down_task_list): 

1798 up_t.set_downstream(down_t) 

1799 
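# Quick illustration of chain()'s pairwise behaviour and its length check
# (dag and task ids are illustrative):
from airflow.sdk import DAG

with DAG(dag_id="chain_demo"):
    t1, t2, t3, t4, t5, t6 = (BaseOperator(task_id=f"t{i}") for i in range(1, 7))
    chain(t1, [t2, t3], [t4, t5], t6)  # t1 fans out; t2 >> t4, t3 >> t5; t6 fans in
    # chain(t1, [t2, t3], [t4, t5, t6]) would raise ValueError (2 vs 3 elements)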

1800 

1801def cross_downstream( 

1802 from_tasks: Sequence[DependencyMixin], 

1803 to_tasks: DependencyMixin | Sequence[DependencyMixin], 

1804): 

1805 r""" 

1806 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks. 

1807 

1808 Using classic operators/sensors: 

1809 

1810 .. code-block:: python 

1811 

1812 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6]) 

1813 

1814 is equivalent to:: 

1815 

1816            t1 ---> t4 

1817               \ / 

1818            t2 -X -> t5 

1819               / \ 

1820            t3 ---> t6 

1821 

1822 .. code-block:: python 

1823 

1824 t1.set_downstream(t4) 

1825 t1.set_downstream(t5) 

1826 t1.set_downstream(t6) 

1827 t2.set_downstream(t4) 

1828 t2.set_downstream(t5) 

1829 t2.set_downstream(t6) 

1830 t3.set_downstream(t4) 

1831 t3.set_downstream(t5) 

1832 t3.set_downstream(t6) 

1833 

1834 Using task-decorated functions aka XComArgs: 

1835 

1836 .. code-block:: python 

1837 

1838 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()]) 

1839 

1840 is equivalent to:: 

1841 

1842            x1 ---> x4 

1843               \ / 

1844            x2 -X -> x5 

1845               / \ 

1846            x3 ---> x6 

1847 

1848 .. code-block:: python 

1849 

1850 x1 = x1() 

1851 x2 = x2() 

1852 x3 = x3() 

1853 x4 = x4() 

1854 x5 = x5() 

1855 x6 = x6() 

1856 x1.set_downstream(x4) 

1857 x1.set_downstream(x5) 

1858 x1.set_downstream(x6) 

1859 x2.set_downstream(x4) 

1860 x2.set_downstream(x5) 

1861 x2.set_downstream(x6) 

1862 x3.set_downstream(x4) 

1863 x3.set_downstream(x5) 

1864 x3.set_downstream(x6) 

1865 

1866 It is also possible to mix classic operator/sensor and XComArg tasks: 

1867 

1868 .. code-block:: python 

1869 

1870 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()]) 

1871 

1872 is equivalent to:: 

1873 

1874            t1 ---> x1 

1875               \ / 

1876            x2 -X -> t2 

1877               / \ 

1878            t3 ---> x3 

1879 

1880 .. code-block:: python 

1881 

1882 x1 = x1() 

1883 x2 = x2() 

1884 x3 = x3() 

1885 t1.set_downstream(x1) 

1886 t1.set_downstream(t2) 

1887 t1.set_downstream(x3) 

1888 x2.set_downstream(x1) 

1889 x2.set_downstream(t2) 

1890 x2.set_downstream(x3) 

1891 t3.set_downstream(x1) 

1892 t3.set_downstream(t2) 

1893 t3.set_downstream(x3) 

1894 

1895 :param from_tasks: List of tasks or XComArgs to start from. 

1896 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies. 

1897 """ 

1898 for task in from_tasks: 

1899 task.set_downstream(to_tasks) 

1900 

1901 

1902def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]): 

1903 """ 

1904 Simplify task dependency definition. 

1905 

1906 E.g.: suppose you want precedence like so:: 

1907 

1908                ╭─op2─╮ ╭─op4─╮ 

1909            op1─┤     ├─┤─op5─┤─op7 

1910                ╰-op3─╯ ╰-op6─╯ 

1911 

1912 Then you can accomplish this like so:: 

1913 

1914 chain_linear(op1, [op2, op3], [op4, op5, op6], op7) 

1915 

1916 :param elements: a list of operators / lists of operators 

1917 """ 

1918 if not elements: 

1919 raise ValueError("No tasks provided; nothing to do.") 

1920 prev_elem = None 

1921 deps_set = False 

1922 for curr_elem in elements: 

1923 if isinstance(curr_elem, EdgeModifier): 

1924 raise ValueError("Labels are not supported by chain_linear") 

1925 if prev_elem is not None: 

1926 for task in prev_elem: 

1927 task >> curr_elem 

1928 if not deps_set: 

1929 deps_set = True 

1930 prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem 

1931 if not deps_set: 

1932 raise ValueError("No dependencies were set. Did you forget to expand with `*`?")
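# Sketch contrasting chain_linear() with chain() (dag and task ids are illustrative):
# chain() pairs adjacent lists element-wise, while chain_linear() connects every
# element of a group to every element of the next group.
from airflow.sdk import DAG

with DAG(dag_id="chain_linear_demo"):
    a, b, c, d, e, f = (BaseOperator(task_id=x) for x in "abcdef")
    chain_linear(a, [b, c], [d, e], f)
    # sets a >> b, a >> c; b >> d, b >> e, c >> d, c >> e; d >> f, e >> f
    # chain(a, [b, c], [d, e], f) would instead pair the lists: only b >> d and c >> e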