Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/bases/operator.py: 36%


679 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import abc 

21import collections.abc 

22import contextlib 

23import copy 

24import inspect 

25import sys 

26import warnings 

27from collections.abc import Callable, Collection, Iterable, Mapping, Sequence 

28from contextvars import ContextVar 

29from dataclasses import dataclass, field 

30from datetime import datetime, timedelta 

31from enum import Enum 

32from functools import total_ordering, wraps 

33from types import FunctionType 

34from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast 

35 

36import attrs 

37 

38from airflow.sdk import TriggerRule, timezone 

39from airflow.sdk._shared.secrets_masker import redact 

40from airflow.sdk.definitions._internal.abstractoperator import ( 

41 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

42 DEFAULT_OWNER, 

43 DEFAULT_POOL_NAME, 

44 DEFAULT_POOL_SLOTS, 

45 DEFAULT_PRIORITY_WEIGHT, 

46 DEFAULT_QUEUE, 

47 DEFAULT_RETRIES, 

48 DEFAULT_RETRY_DELAY, 

49 DEFAULT_TASK_EXECUTION_TIMEOUT, 

50 DEFAULT_TRIGGER_RULE, 

51 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

52 DEFAULT_WEIGHT_RULE, 

53 AbstractOperator, 

54 DependencyMixin, 

55 TaskStateChangeCallback, 

56) 

57from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack 

58from airflow.sdk.definitions._internal.node import validate_key 

59from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext 

60from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args 

61from airflow.sdk.definitions.edges import EdgeModifier 

62from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs 

63from airflow.sdk.definitions.param import ParamsDict 

64from airflow.sdk.exceptions import RemovedInAirflow4Warning 

65from airflow.task.priority_strategy import ( 

66 PriorityWeightStrategy, 

67 airflow_priority_weight_strategies, 

68 validate_and_load_priority_weight_strategy, 

69) 

70 

71# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights. 

72# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html) 

73# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html) 

74# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html) 

75DB_SAFE_MINIMUM = -2147483648 

76DB_SAFE_MAXIMUM = 2147483647 

77 

78 

79def db_safe_priority(priority_weight: int) -> int: 

80 """Convert priority weight to a safe value for the database.""" 

81 return max(DB_SAFE_MINIMUM, min(DB_SAFE_MAXIMUM, priority_weight)) 
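# Editor's note (illustrative, not part of the original source): db_safe_priority simply
# clamps into the 32-bit signed range defined above, e.g.:
#   db_safe_priority(10)          -> 10
#   db_safe_priority(2 ** 40)     -> 2147483647   (DB_SAFE_MAXIMUM)
#   db_safe_priority(-(2 ** 40))  -> -2147483648  (DB_SAFE_MINIMUM)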

82 

83 

84C = TypeVar("C", bound=Callable) 

85T = TypeVar("T", bound=FunctionType) 

86 

87if TYPE_CHECKING: 

88 from types import ClassMethodDescriptorType 

89 

90 import jinja2 

91 from typing_extensions import Self 

92 

93 from airflow.sdk.bases.operatorlink import BaseOperatorLink 

94 from airflow.sdk.definitions.context import Context 

95 from airflow.sdk.definitions.dag import DAG 

96 from airflow.sdk.definitions.operator_resources import Resources 

97 from airflow.sdk.definitions.taskgroup import TaskGroup 

98 from airflow.sdk.definitions.xcom_arg import XComArg 

99 from airflow.serialization.enums import DagAttributeTypes 

100 from airflow.task.priority_strategy import PriorityWeightStrategy 

101 from airflow.triggers.base import BaseTrigger, StartTriggerArgs 

102 

103 TaskPreExecuteHook = Callable[[Context], None] 

104 TaskPostExecuteHook = Callable[[Context, Any], None] 

105 

106__all__ = [ 

107 "BaseOperator", 

108 "chain", 

109 "chain_linear", 

110 "cross_downstream", 

111] 

112 

113 

114class TriggerFailureReason(str, Enum): 

115 """ 

116 Reasons for trigger failures. 

117 

118 Internal use only. 

119 

120 :meta private: 

121 """ 

122 

123 TRIGGER_TIMEOUT = "Trigger timeout" 

124 TRIGGER_FAILURE = "Trigger failure" 

125 

126 

127TRIGGER_FAIL_REPR = "__fail__" 

128"""String value to represent trigger failure. 

129 

130Internal use only. 

131 

132:meta private: 

133""" 

134 

135 

136def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]: 

137 if not dag: 

138 return {}, ParamsDict() 

139 dag_args = copy.copy(dag.default_args) 

140 dag_params = copy.deepcopy(dag.params) 

141 dag_params._fill_missing_param_source("dag") 

142 if task_group: 

143 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping): 

144 raise TypeError("default_args must be a mapping") 

145 dag_args.update(task_group.default_args) 

146 return dag_args, dag_params 

147 

148 

149def get_merged_defaults( 

150 dag: DAG | None, 

151 task_group: TaskGroup | None, 

152 task_params: collections.abc.MutableMapping | None, 

153 task_default_args: dict | None, 

154) -> tuple[dict, ParamsDict]: 

155 args, params = _get_parent_defaults(dag, task_group) 

156 if task_params: 

157 if not isinstance(task_params, collections.abc.Mapping): 

158 raise TypeError(f"params must be a mapping, got {type(task_params)}") 

159 

160 task_params = ParamsDict(task_params) 

161 task_params._fill_missing_param_source("task") 

162 params.update(task_params) 

163 

164 if task_default_args: 

165 if not isinstance(task_default_args, collections.abc.Mapping): 

166 raise TypeError(f"default_args must be a mapping, got {type(task_params)}") 

167 args.update(task_default_args) 

168 with contextlib.suppress(KeyError): 

169 if params_from_default_args := ParamsDict(task_default_args["params"] or {}): 

170 params_from_default_args._fill_missing_param_source("task") 

171 params.update(params_from_default_args) 

172 

173 return args, params 
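# Editor's note (hedged sketch, names hypothetical): merging precedence is
# dag.default_args < task_group.default_args < the task's own default_args, while dag-level
# params are overlaid with task-level params, e.g.:
#   args, params = get_merged_defaults(
#       dag=my_dag,                          # default_args={"retries": 3}
#       task_group=None,
#       task_params={"limit": 10},           # tagged with source "task"
#       task_default_args={"retries": 5},    # wins over the dag value
#   )
#   # args["retries"] == 5; params holds the dag params plus "limit"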

174 

175 

176def parse_retries(retries: Any) -> int | None: 

177 if retries is None: 

178 return 0 

179 if type(retries) == int: # noqa: E721 

180 return retries 

181 try: 

182 parsed_retries = int(retries) 

183 except (TypeError, ValueError): 

184 raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}") 

185 return parsed_retries 
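# Editor's note (illustrative): parse_retries accepts anything int() can coerce, e.g.
#   parse_retries(None)   -> 0
#   parse_retries(3)      -> 3
#   parse_retries("3")    -> 3
#   parse_retries("oops")  raises RuntimeError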

186 

187 

188def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta: 

189 if isinstance(value, timedelta): 

190 return value 

191 return timedelta(seconds=value) 
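# Editor's note (illustrative): plain numbers are treated as seconds, e.g.
#   coerce_timedelta(300) == timedelta(seconds=300); an existing timedelta is returned unchanged.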

192 

193 

194def coerce_resources(resources: dict[str, Any] | None) -> Resources | None: 

195 if resources is None: 

196 return None 

197 from airflow.sdk.definitions.operator_resources import Resources 

198 

199 return Resources(**resources) 

200 

201 

202class _PartialDescriptor: 

203 """A descriptor that guards against ``.partial`` being called on Task objects.""" 

204 

205 class_method: ClassMethodDescriptorType | None = None 

206 

207 def __get__( 

208 self, obj: BaseOperator, cls: type[BaseOperator] | None = None 

209 ) -> Callable[..., OperatorPartial]: 

210 # Call this "partial" so it looks nicer in stack traces. 

211 def partial(**kwargs): 

212 raise TypeError("partial can only be called on Operator classes, not Tasks themselves") 

213 

214 if obj is not None: 

215 return partial 

216 return self.class_method.__get__(cls, cls) 
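# Editor's note (illustrative, operator name hypothetical): the descriptor allows
#   MyOperator.partial(task_id="t")   # class access -> delegates to the real classmethod
# but rejects
#   my_task.partial(task_id="t")      # instance access -> TypeError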

217 

218 

219OPERATOR_DEFAULTS: dict[str, Any] = { 

220 "allow_nested_operators": True, 

221 "depends_on_past": False, 

222 "email_on_failure": True, 

223 "email_on_retry": True, 

224 "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, 

225 # "executor": DEFAULT_EXECUTOR, 

226 "executor_config": {}, 

227 "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

228 "inlets": [], 

229 "map_index_template": None, 

230 "on_execute_callback": [], 

231 "on_failure_callback": [], 

232 "on_retry_callback": [], 

233 "on_skipped_callback": [], 

234 "on_success_callback": [], 

235 "outlets": [], 

236 "owner": DEFAULT_OWNER, 

237 "pool_slots": DEFAULT_POOL_SLOTS, 

238 "priority_weight": DEFAULT_PRIORITY_WEIGHT, 

239 "queue": DEFAULT_QUEUE, 

240 "retries": DEFAULT_RETRIES, 

241 "retry_delay": DEFAULT_RETRY_DELAY, 

242 "retry_exponential_backoff": 0, 

243 "trigger_rule": DEFAULT_TRIGGER_RULE, 

244 "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

245 "wait_for_downstream": False, 

246 "weight_rule": DEFAULT_WEIGHT_RULE, 

247} 

248 

249 

250# This is what handles the actual mapping. 

251 

252if TYPE_CHECKING: 

253 

254 def partial( 

255 operator_class: type[BaseOperator], 

256 *, 

257 task_id: str, 

258 dag: DAG | None = None, 

259 task_group: TaskGroup | None = None, 

260 start_date: datetime = ..., 

261 end_date: datetime = ..., 

262 owner: str = ..., 

263 email: None | str | Iterable[str] = ..., 

264 params: collections.abc.MutableMapping | None = None, 

265 resources: dict[str, Any] | None = ..., 

266 trigger_rule: str = ..., 

267 depends_on_past: bool = ..., 

268 ignore_first_depends_on_past: bool = ..., 

269 wait_for_past_depends_before_skipping: bool = ..., 

270 wait_for_downstream: bool = ..., 

271 retries: int | None = ..., 

272 queue: str = ..., 

273 pool: str = ..., 

274 pool_slots: int = ..., 

275 execution_timeout: timedelta | None = ..., 

276 max_retry_delay: None | timedelta | float = ..., 

277 retry_delay: timedelta | float = ..., 

278 retry_exponential_backoff: float = ..., 

279 priority_weight: int = ..., 

280 weight_rule: str | PriorityWeightStrategy = ..., 

281 sla: timedelta | None = ..., 

282 map_index_template: str | None = ..., 

283 max_active_tis_per_dag: int | None = ..., 

284 max_active_tis_per_dagrun: int | None = ..., 

285 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

286 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

287 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

288 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

289 on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

290 run_as_user: str | None = ..., 

291 executor: str | None = ..., 

292 executor_config: dict | None = ..., 

293 inlets: Any | None = ..., 

294 outlets: Any | None = ..., 

295 doc: str | None = ..., 

296 doc_md: str | None = ..., 

297 doc_json: str | None = ..., 

298 doc_yaml: str | None = ..., 

299 doc_rst: str | None = ..., 

300 task_display_name: str | None = ..., 

301 logger_name: str | None = ..., 

302 allow_nested_operators: bool = True, 

303 **kwargs, 

304 ) -> OperatorPartial: ... 

305else: 

306 

307 def partial( 

308 operator_class: type[BaseOperator], 

309 *, 

310 task_id: str, 

311 dag: DAG | None = None, 

312 task_group: TaskGroup | None = None, 

313 params: collections.abc.MutableMapping | None = None, 

314 **kwargs, 

315 ): 

316 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext 

317 

318 validate_mapping_kwargs(operator_class, "partial", kwargs) 

319 

320 dag = dag or DagContext.get_current() 

321 if dag: 

322 task_group = task_group or TaskGroupContext.get_current(dag) 

323 if task_group: 

324 task_id = task_group.child_id(task_id) 

325 

326 # Merge Dag and task group level defaults into user-supplied values. 

327 dag_default_args, partial_params = get_merged_defaults( 

328 dag=dag, 

329 task_group=task_group, 

330 task_params=params, 

331 task_default_args=kwargs.pop("default_args", None), 

332 ) 

333 

334 # Create partial_kwargs from args and kwargs 

335 partial_kwargs: dict[str, Any] = { 

336 "task_id": task_id, 

337 "dag": dag, 

338 "task_group": task_group, 

339 **kwargs, 

340 } 

341 

342 # Inject Dag-level default args into args provided to this function. 

343 # Most of the default args will be retrieved during unmapping; here we 

344 # only ensure base properties are correctly set for the scheduler. 

345 partial_kwargs.update( 

346 (k, v) 

347 for k, v in dag_default_args.items() 

348 if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names 

349 ) 

350 

351 # Fill fields not provided by the user with default values. 

352 partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs) 

353 

354 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). 

355 if "task_concurrency" in kwargs: # Reject deprecated option. 

356 raise TypeError("unexpected argument: task_concurrency") 

357 if start_date := partial_kwargs.get("start_date", None): 

358 partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) 

359 if end_date := partial_kwargs.get("end_date", None): 

360 partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) 

361 if partial_kwargs["pool_slots"] < 1: 

362 dag_str = "" 

363 if dag: 

364 dag_str = f" in dag {dag.dag_id}" 

365 raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") 

366 if retries := partial_kwargs.get("retries"): 

367 partial_kwargs["retries"] = BaseOperator._convert_retries(retries) 

368 partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"]) 

369 partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay( 

370 partial_kwargs.get("max_retry_delay", None) 

371 ) 

372 

373 for k in ("execute", "failure", "success", "retry", "skipped"): 

374 partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr := f"on_{k}_callback")) 

375 

376 return OperatorPartial( 

377 operator_class=operator_class, 

378 kwargs=partial_kwargs, 

379 params=partial_params, 

380 ) 
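# Editor's note (hedged usage sketch, operator and argument names hypothetical): partial()
# is normally reached via `MyOperator.partial(...)` when building a mapped task; fields the
# caller does not set fall back to dag default_args and then OPERATOR_DEFAULTS:
#   op = MyOperator.partial(task_id="copy", retries=2).expand(src=["a", "b"])
# `.expand()` lives on the returned OperatorPartial (imported from mappedoperator above).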

381 

382 

383class ExecutorSafeguard: 

384 """ 

385 The ExecutorSafeguard decorator. 

386 

387 Checks if the execute method of an operator isn't manually called outside 

388 the TaskInstance as we want to avoid bad mixing between decorated and 

389 classic operators. 

390 """ 

391 

392 test_mode: ClassVar[bool] = False 

393 tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel") 

394 sentinel_value: ClassVar[object] = object() 

395 

396 @classmethod 

397 def decorator(cls, func): 

398 @wraps(func) 

399 def wrapper(self, *args, **kwargs): 

400 sentinel_key = f"{self.__class__.__name__}__sentinel" 

401 sentinel = kwargs.pop(sentinel_key, None) 

402 

403 with contextlib.ExitStack() as stack: 

404 if sentinel is cls.sentinel_value: 

405 token = cls.tracker.set(self) 

406 sentinel = self 

407 stack.callback(cls.tracker.reset, token) 

408 else: 

409 # No sentinel passed in; maybe a wrapping subclass execute call received it. 

410 sentinel = cls.tracker.get(None) 

411 

412 if not cls.test_mode and sentinel is not self: 

413 message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!" 

414 if not self.allow_nested_operators: 

415 raise RuntimeError(message) 

416 self.log.warning(message) 

417 

418 # Now that we've logged, set sentinel so that `super()` calls don't log again 

419 token = cls.tracker.set(self) 

420 stack.callback(cls.tracker.reset, token) 

421 

422 return func(self, *args, **kwargs) 

423 

424 return wrapper 
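# Editor's note: the wrapper above works by sentinel -- the Task Runner passes
# `<ClassName>__sentinel=ExecutorSafeguard.sentinel_value` into execute(), which is stored
# in the `tracker` ContextVar so nested super().execute() calls are still recognised.
# A direct call such as `op.execute(context)` inside another task carries no sentinel and
# therefore logs a warning (or raises if allow_nested_operators is False).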

425 

426 

427if "airflow.configuration" in sys.modules: 

428 # Don't try to import it if it's not already loaded 

429 from airflow.sdk.configuration import conf 

430 

431 ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode") 

432 

433 

434def _collect_from_input(value_or_values: None | C | Collection[C]) -> list[C]: 

435 if not value_or_values: 

436 return [] 

437 if isinstance(value_or_values, Collection): 

438 return list(value_or_values) 

439 return [value_or_values] 
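# Editor's note (illustrative, callback names hypothetical): callback arguments are
# normalised to a list, e.g.
#   _collect_from_input(None)      -> []
#   _collect_from_input(notify)    -> [notify]
#   _collect_from_input([f, g])    -> [f, g]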

440 

441 

442class BaseOperatorMeta(abc.ABCMeta): 

443 """Metaclass of BaseOperator.""" 

444 

445 @classmethod 

446 def _apply_defaults(cls, func: T) -> T: 

447 """ 

448 Look for an argument named "default_args", and fill the unspecified arguments from it. 

449 

450 Since python2.* isn't clear about which arguments are missing when 

451 calling a function, and this can be quite confusing with multi-level 

452 inheritance and argument defaults, this decorator also alerts with 

453 specific information about the missing arguments. 

454 """ 

455 # Cache inspect.signature for the wrapper closure to avoid calling it 

456 # at every decorated invocation. This is separate sig_cache created 

457 # per decoration, i.e. each function decorated using apply_defaults will 

458 # have a different sig_cache. 

459 sig_cache = inspect.signature(func) 

460 non_variadic_params = { 

461 name: param 

462 for (name, param) in sig_cache.parameters.items() 

463 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) 

464 } 

465 non_optional_args = { 

466 name 

467 for name, param in non_variadic_params.items() 

468 if param.default == param.empty and name != "task_id" 

469 } 

470 

471 fixup_decorator_warning_stack(func) 

472 

473 @wraps(func) 

474 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: 

475 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext 

476 

477 if args: 

478 raise TypeError("Use keyword arguments when initializing operators") 

479 

480 instantiated_from_mapped = kwargs.pop( 

481 "_airflow_from_mapped", 

482 getattr(self, "_BaseOperator__from_mapped", False), 

483 ) 

484 

485 dag: DAG | None = kwargs.get("dag") 

486 if dag is None: 

487 dag = DagContext.get_current() 

488 if dag is not None: 

489 kwargs["dag"] = dag 

490 

491 task_group: TaskGroup | None = kwargs.get("task_group") 

492 if dag and not task_group: 

493 task_group = TaskGroupContext.get_current(dag) 

494 if task_group is not None: 

495 kwargs["task_group"] = task_group 

496 

497 default_args, merged_params = get_merged_defaults( 

498 dag=dag, 

499 task_group=task_group, 

500 task_params=kwargs.pop("params", None), 

501 task_default_args=kwargs.pop("default_args", None), 

502 ) 

503 

504 for arg in sig_cache.parameters: 

505 if arg not in kwargs and arg in default_args: 

506 kwargs[arg] = default_args[arg] 

507 

508 missing_args = non_optional_args.difference(kwargs) 

509 if len(missing_args) == 1: 

510 raise TypeError(f"missing keyword argument {missing_args.pop()!r}") 

511 if missing_args: 

512 display = ", ".join(repr(a) for a in sorted(missing_args)) 

513 raise TypeError(f"missing keyword arguments {display}") 

514 

515 if merged_params: 

516 kwargs["params"] = merged_params 

517 

518 hook = getattr(self, "_hook_apply_defaults", None) 

519 if hook: 

520 args, kwargs = hook(**kwargs, default_args=default_args) 

521 default_args = kwargs.pop("default_args", {}) 

522 

523 if not hasattr(self, "_BaseOperator__init_kwargs"): 

524 object.__setattr__(self, "_BaseOperator__init_kwargs", {}) 

525 object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped) 

526 

527 result = func(self, **kwargs, default_args=default_args) 

528 

529 # Store the args passed to init -- we need them to support task.map serialization! 

530 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore 

531 

532 # Set upstream task defined by XComArgs passed to template fields of the operator. 

533 # BUT: only do this _ONCE_, not once for each class in the hierarchy 

534 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc] 

535 self._set_xcomargs_dependencies() 

536 # Mark instance as instantiated so that future attr setting updates xcomarg-based deps. 

537 object.__setattr__(self, "_BaseOperator__instantiated", True) 

538 

539 return result 

540 

541 apply_defaults.__non_optional_args = non_optional_args # type: ignore 

542 apply_defaults.__param_names = set(non_variadic_params) # type: ignore 

543 

544 return cast("T", apply_defaults) 

545 

546 def __new__(cls, name, bases, namespace, **kwargs): 

547 execute_method = namespace.get("execute") 

548 if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False): 

549 namespace["execute"] = ExecutorSafeguard.decorator(execute_method) 

550 new_cls = super().__new__(cls, name, bases, namespace, **kwargs) 

551 with contextlib.suppress(KeyError): 

552 # Update the partial descriptor with the class method, so it calls the actual function 

553 # (but let subclasses override it if they need to) 

554 partial_desc = vars(new_cls)["partial"] 

555 if isinstance(partial_desc, _PartialDescriptor): 

556 partial_desc.class_method = classmethod(partial) 

557 

558 # We patch `__init__` only if the class defines it. 

559 first_superclass = new_cls.mro()[1] 

560 if new_cls.__init__ is not first_superclass.__init__: 

561 new_cls.__init__ = cls._apply_defaults(new_cls.__init__) 

562 

563 return new_cls 
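# Editor's note (hedged sketch, names hypothetical): through _apply_defaults, keyword
# arguments that are omitted from an operator call are filled from the surrounding
# Dag/TaskGroup default_args, e.g.:
#   with DAG("demo", default_args={"retries": 3, "owner": "data-eng"}):
#       t = EmptyOperator(task_id="t")   # picks up retries=3 and owner="data-eng"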

564 

565 

566# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the 

567# correct type. This is a temporary solution until we find a more sophisticated method for argument 

568# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not 

569# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python 

570# version that supports `get_type_hints` effectively or find a better approach, we can replace this 

571# manual type-checking method. 

572BASEOPERATOR_ARGS_EXPECTED_TYPES = { 

573 "task_id": str, 

574 "email": (str, Sequence), 

575 "email_on_retry": bool, 

576 "email_on_failure": bool, 

577 "retries": int, 

578 "retry_exponential_backoff": (int, float), 

579 "depends_on_past": bool, 

580 "ignore_first_depends_on_past": bool, 

581 "wait_for_past_depends_before_skipping": bool, 

582 "wait_for_downstream": bool, 

583 "priority_weight": int, 

584 "queue": str, 

585 "pool": str, 

586 "pool_slots": int, 

587 "trigger_rule": str, 

588 "run_as_user": str, 

589 "task_concurrency": int, 

590 "map_index_template": str, 

591 "max_active_tis_per_dag": int, 

592 "max_active_tis_per_dagrun": int, 

593 "executor": str, 

594 "do_xcom_push": bool, 

595 "multiple_outputs": bool, 

596 "doc": str, 

597 "doc_md": str, 

598 "doc_json": str, 

599 "doc_yaml": str, 

600 "doc_rst": str, 

601 "task_display_name": str, 

602 "logger_name": str, 

603 "allow_nested_operators": bool, 

604 "start_date": datetime, 

605 "end_date": datetime, 

606} 

607 

608 

609# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in 

610# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it. 

611# 

612# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you 

613# get nowhere leave your record here for the next poor soul and what problems you ran in to. 

614# 

615# @ashb, 2024/10/14 

616# - "Can't combine custom __setattr__ with on_setattr hooks" 

617 # - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses 

618@total_ordering 

619@dataclass(repr=False) 

620class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): 

621 r""" 

622 Abstract base class for all operators. 

623 

624 Since operators create objects that become nodes in the Dag, BaseOperator 

625 contains many recursive methods for Dag crawling behavior. To derive from 

626 this class, you are expected to override the constructor and the 'execute' 

627 method. 

628 

629 Operators derived from this class should perform or trigger certain tasks 

630 synchronously (wait for completion). Example of operators could be an 

631 operator that runs a Pig job (PigOperator), a sensor operator that 

632 waits for a partition to land in Hive (HiveSensorOperator), or one that 

633 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these 

634 operators (tasks) target specific operations, running specific scripts, 

635 functions or data transfers. 

636 

637 This class is abstract and shouldn't be instantiated. Instantiating a 

638 class derived from this one results in the creation of a task object, 

639 which ultimately becomes a node in Dag objects. Task dependencies should 

640 be set by using the set_upstream and/or set_downstream methods. 

641 

642 :param task_id: a unique, meaningful id for the task 

643 :param owner: the owner of the task. Using a meaningful description 

644 (e.g. user/person/team/role name) to clarify ownership is recommended. 

645 :param email: the 'to' email address(es) used in email alerts. This can be a 

646 single email or multiple ones. Multiple addresses can be specified as a 

647 comma or semicolon separated string or by passing a list of strings. (deprecated) 

648 :param email_on_retry: Indicates whether email alerts should be sent when a 

649 task is retried (deprecated) 

650 :param email_on_failure: Indicates whether email alerts should be sent when 

651 a task failed (deprecated) 

652 :param retries: the number of retries that should be performed before 

653 failing the task 

654 :param retry_delay: delay between retries, can be set as ``timedelta`` or 

655 ``float`` seconds, which will be converted into ``timedelta``, 

656 the default is ``timedelta(seconds=300)``. 

657 :param retry_exponential_backoff: multiplier for exponential backoff between retries. 

658 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff 

659 (delay doubles with each retry). For example, with retry_delay=4min and 

660 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc. 

661 :param max_retry_delay: maximum delay interval between retries, can be set as 

662 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``. 

663 :param start_date: The ``start_date`` for the task, determines 

664 the ``logical_date`` for the first task instance. The best practice 

665 is to have the start_date rounded 

666 to your Dag's ``schedule_interval``. Daily jobs have their start_date 

667 some day at 00:00:00, hourly jobs have their start_date at 00:00 

668 of a specific hour. Note that Airflow simply looks at the latest 

669 ``logical_date`` and adds the ``schedule_interval`` to determine 

670 the next ``logical_date``. It is also very important 

671 to note that different tasks' dependencies 

672 need to line up in time. If task A depends on task B and their 

673 start_date are offset in a way that their logical_date don't line 

674 up, A's dependencies will never be met. If you are looking to delay 

675 a task, for example running a daily task at 2AM, look into the 

676 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using 

677 dynamic ``start_date`` and recommend using fixed ones. Read the 

678 FAQ entry about start_date for more information. 

679 :param end_date: if specified, the scheduler won't go beyond this date 

680 :param depends_on_past: when set to true, task instances will run 

681 sequentially and only if the previous instance has succeeded or has been skipped. 

682 The task instance for the start_date is allowed to run. 

683 :param wait_for_past_depends_before_skipping: when set to true, if the task instance 

684 should be marked as skipped, and depends_on_past is true, the ti will stay in None state 

685 waiting for the task of the previous run 

686 :param wait_for_downstream: when set to true, an instance of task 

687 X will wait for tasks immediately downstream of the previous instance 

688 of task X to finish successfully or be skipped before it runs. This is useful if the 

689 different instances of a task X alter the same asset, and this asset 

690 is used by tasks downstream of task X. Note that depends_on_past 

691 is forced to True wherever wait_for_downstream is used. Also note that 

692 only tasks *immediately* downstream of the previous task instance are waited 

693 for; the statuses of any tasks further downstream are ignored. 

694 :param dag: a reference to the dag the task is attached to (if any) 

695 :param priority_weight: priority weight of this task against other task. 

696 This allows the executor to trigger higher priority tasks before 

697 others when things get backed up. Set priority_weight as a higher 

698 number for more important tasks. 

699 As not all database engines support 64-bit integers, values are capped to the 32-bit signed range. 

700 Valid range is from -2,147,483,648 to 2,147,483,647. 

701 :param weight_rule: weighting method used for the effective total 

702 priority weight of the task. Options are: 

703 ``{ downstream | upstream | absolute }`` default is ``downstream`` 

704 When set to ``downstream`` the effective weight of the task is the 

705 aggregate sum of all downstream descendants. As a result, upstream 

706 tasks will have higher weight and will be scheduled more aggressively 

707 when using positive weight values. This is useful when you have 

708 multiple dag run instances and desire to have all upstream tasks to 

709 complete for all runs before each dag can continue processing 

710 downstream tasks. When set to ``upstream`` the effective weight is the 

711 aggregate sum of all upstream ancestors. This is the opposite where 

712 downstream tasks have higher weight and will be scheduled more 

713 aggressively when using positive weight values. This is useful when you 

714 have multiple dag run instances and prefer to have each dag complete 

715 before starting upstream tasks of other dags. When set to 

716 ``absolute``, the effective weight is the exact ``priority_weight`` 

717 specified without additional weighting. You may want to do this when 

718 you know exactly what priority weight each task should have. 

719 Additionally, when set to ``absolute``, there is a bonus effect of 

720 significantly speeding up the task creation process for very large 

721 Dags. Options can be set as string or using the constants defined in 

722 the static class ``airflow.utils.WeightRule``. 

723 Irrespective of the weight rule, resulting priority values are capped with 32-bit. 

724 |experimental| 

725 Since 2.9.0, Airflow allows defining a custom priority weight strategy 

726 by creating a subclass of 

727 ``airflow.task.priority_strategy.PriorityWeightStrategy`` and registering it 

728 in a plugin, then providing the class path or the class instance via the 

729 ``weight_rule`` parameter. The custom priority weight strategy will be 

730 used to calculate the effective total priority weight of the task instance. 

731 :param queue: which queue to target when running this job. Not 

732 all executors implement queue management; the CeleryExecutor 

733 does support targeting specific queues. 

734 :param pool: the slot pool this task should run in, slot pools are a 

735 way to limit concurrency for certain tasks 

736 :param pool_slots: the number of pool slots this task should use (>= 1) 

737 Values less than 1 are not allowed. 

738 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a 

739 new implementation in Airflow >=3.1. 

740 :param execution_timeout: max time allowed for the execution of 

741 this task instance, if it goes beyond it will raise and fail. 

742 :param on_failure_callback: a function or list of functions to be called when a task instance 

743 of this task fails. a context dictionary is passed as a single 

744 parameter to this function. Context contains references to related 

745 objects to the task instance and is documented under the macros 

746 section of the API. 

747 :param on_execute_callback: much like the ``on_failure_callback`` except 

748 that it is executed right before the task is executed. 

749 :param on_retry_callback: much like the ``on_failure_callback`` except 

750 that it is executed when retries occur. 

751 :param on_success_callback: much like the ``on_failure_callback`` except 

752 that it is executed when the task succeeds. 

753 :param on_skipped_callback: much like the ``on_failure_callback`` except 

754 that it is executed when the task is skipped; this callback is called only if AirflowSkipException is raised. 

755 It is explicitly NOT called if the task never starts executing because of a preceding branching 

756 decision in the Dag or a trigger rule that causes execution to be skipped, so that the task execution 

757 is never scheduled. 

758 :param pre_execute: a function to be called immediately before task 

759 execution, receiving a context dictionary; raising an exception will 

760 prevent the task from being executed. 

761 

762 |experimental| 

763 :param post_execute: a function to be called immediately after task 

764 execution, receiving a context dictionary and task result; raising an 

765 exception will prevent the task from succeeding. 

766 

767 |experimental| 

768 :param trigger_rule: defines the rule by which dependencies are applied 

769 for the task to get triggered. Options are: 

770 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done | 

771 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}`` 

772 default is ``all_success``. Options can be set as string or 

773 using the constants defined in the static class 

774 ``airflow.utils.TriggerRule`` 

775 :param resources: A map of resource parameter names (the argument names of the 

776 Resources constructor) to their values. 

777 :param run_as_user: unix username to impersonate while running the task 

778 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent 

779 runs across logical_dates. 

780 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent 

781 task instances per Dag run. 

782 :param executor: Which executor to target when running this task. NOT YET SUPPORTED 

783 :param executor_config: Additional task-level configuration parameters that are 

784 interpreted by a specific executor. Parameters are namespaced by the name of 

785 executor. 

786 

787 **Example**: to run this task in a specific docker container through 

788 the KubernetesExecutor :: 

789 

790 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}) 

791 

792 :param do_xcom_push: if True, an XCom is pushed containing the Operator's 

793 result 

794 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each 

795 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom. 

796 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not 

797 using a TaskGroup as a context manager. 

798 :param doc: Add documentation or notes to your Task objects that is visible in 

799 Task Instance details View in the Webserver 

800 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects 

801 that is visible in Task Instance details View in the Webserver 

802 :param doc_rst: Add documentation (in RST format) or notes to your Task objects 

803 that is visible in Task Instance details View in the Webserver 

804 :param doc_json: Add documentation (in JSON format) or notes to your Task objects 

805 that is visible in Task Instance details View in the Webserver 

806 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects 

807 that is visible in Task Instance details View in the Webserver 

808 :param task_display_name: The display name of the task which appears on the UI. 

809 :param logger_name: Name of the logger used by the Operator to emit logs. 

810 If set to `None` (default), the logger name will fall back to 

811 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have 

812 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger). 

813 :param allow_nested_operators: if True, when an operator is executed within another one, a warning message 

814 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested 

815 within another one). In future releases of Airflow this parameter will be removed and an exception 

816 will always be thrown when operators are nested within each other (default is True). 

817 

818 **Example**: example of a bad operator mixin usage:: 

819 

820 @task(provide_context=True) 

821 def say_hello_world(**context): 

822 hello_world_task = BashOperator( 

823 task_id="hello_world_task", 

824 bash_command="python -c \"print('Hello, world!')\"", 

825 dag=dag, 

826 ) 

827 hello_world_task.execute(context) 

828 """ 

829 

830 task_id: str 

831 owner: str = DEFAULT_OWNER 

832 email: str | Sequence[str] | None = None 

833 email_on_retry: bool = True 

834 email_on_failure: bool = True 

835 retries: int | None = DEFAULT_RETRIES 

836 retry_delay: timedelta = DEFAULT_RETRY_DELAY 

837 retry_exponential_backoff: float = 0 

838 max_retry_delay: timedelta | float | None = None 

839 start_date: datetime | None = None 

840 end_date: datetime | None = None 

841 depends_on_past: bool = False 

842 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST 

843 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING 

844 wait_for_downstream: bool = False 

845 

846 # At execution_time this becomes a normal dict 

847 params: ParamsDict | dict = field(default_factory=ParamsDict) 

848 default_args: dict | None = None 

849 priority_weight: int = DEFAULT_PRIORITY_WEIGHT 

850 weight_rule: PriorityWeightStrategy = field( 

851 default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE] 

852 ) 

853 queue: str = DEFAULT_QUEUE 

854 pool: str = DEFAULT_POOL_NAME 

855 pool_slots: int = DEFAULT_POOL_SLOTS 

856 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT 

857 on_execute_callback: Sequence[TaskStateChangeCallback] = () 

858 on_failure_callback: Sequence[TaskStateChangeCallback] = () 

859 on_success_callback: Sequence[TaskStateChangeCallback] = () 

860 on_retry_callback: Sequence[TaskStateChangeCallback] = () 

861 on_skipped_callback: Sequence[TaskStateChangeCallback] = () 

862 _pre_execute_hook: TaskPreExecuteHook | None = None 

863 _post_execute_hook: TaskPostExecuteHook | None = None 

864 trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE 

865 resources: dict[str, Any] | None = None 

866 run_as_user: str | None = None 

867 task_concurrency: int | None = None 

868 map_index_template: str | None = None 

869 max_active_tis_per_dag: int | None = None 

870 max_active_tis_per_dagrun: int | None = None 

871 executor: str | None = None 

872 executor_config: dict | None = None 

873 do_xcom_push: bool = True 

874 multiple_outputs: bool = False 

875 inlets: list[Any] = field(default_factory=list) 

876 outlets: list[Any] = field(default_factory=list) 

877 task_group: TaskGroup | None = None 

878 doc: str | None = None 

879 doc_md: str | None = None 

880 doc_json: str | None = None 

881 doc_yaml: str | None = None 

882 doc_rst: str | None = None 

883 _task_display_name: str | None = None 

884 logger_name: str | None = None 

885 allow_nested_operators: bool = True 

886 

887 is_setup: bool = False 

888 is_teardown: bool = False 

889 

890 # TODO: Task-SDK: Make these ClassVar[]? 

891 template_fields: Collection[str] = () 

892 template_ext: Sequence[str] = () 

893 

894 template_fields_renderers: ClassVar[dict[str, str]] = {} 

895 

896 operator_extra_links: Collection[BaseOperatorLink] = () 

897 

898 # Defines the color in the UI 

899 ui_color: str = "#fff" 

900 ui_fgcolor: str = "#000" 

901 

902 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore 

903 

904 _dag: DAG | None = field(init=False, default=None) 

905 

906 # Make this optional so the type matches the one defined in LoggingMixin 

907 _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False) 

908 _logger_name: str | None = None 

909 

910 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called 

911 __serialized_fields: ClassVar[frozenset[str] | None] = None 

912 

913 _comps: ClassVar[set[str]] = { 

914 "task_id", 

915 "dag_id", 

916 "owner", 

917 "email", 

918 "email_on_retry", 

919 "retry_delay", 

920 "retry_exponential_backoff", 

921 "max_retry_delay", 

922 "start_date", 

923 "end_date", 

924 "depends_on_past", 

925 "wait_for_downstream", 

926 "priority_weight", 

927 "execution_timeout", 

928 "has_on_execute_callback", 

929 "has_on_failure_callback", 

930 "has_on_success_callback", 

931 "has_on_retry_callback", 

932 "has_on_skipped_callback", 

933 "do_xcom_push", 

934 "multiple_outputs", 

935 "allow_nested_operators", 

936 "executor", 

937 } 

938 

939 # If True, the Rendered Template fields will be overwritten in DB after execution 

940 # This is useful for Taskflow decorators that modify the template fields during execution like 

941 # @task.bash decorator. 

942 overwrite_rtif_after_execution: bool = False 

943 

944 # If True then the class constructor was called 

945 __instantiated: bool = False 

946 # List of args as passed to `init()`, after apply_defaults() has updated them. Used to "recreate" the task 

947 # when mapping 

948 # Set via the metaclass 

949 __init_kwargs: dict[str, Any] = field(init=False) 

950 

951 # Set to True before calling execute method 

952 _lock_for_execution: bool = False 

953 

954 # Set to True for an operator instantiated by a mapped operator. 

955 __from_mapped: bool = False 

956 

957 start_trigger_args: StartTriggerArgs | None = None 

958 start_from_trigger: bool = False 

959 

960 # base list which includes all the attrs that don't need deep copy. 

961 _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = ( 

962 "user_defined_macros", 

963 "user_defined_filters", 

964 "params", 

965 ) 

966 

967 # each operator should override this class attr for shallow copy attrs. 

968 shallow_copy_attrs: Sequence[str] = () 

969 

970 def __setattr__(self: BaseOperator, key: str, value: Any): 

971 if converter := getattr(self, f"_convert_{key}", None): 

972 value = converter(value) 

973 super().__setattr__(key, value) 

974 if self.__from_mapped or self._lock_for_execution: 

975 return # Skip any custom behavior for validation and during execute. 

976 if key in self.__init_kwargs: 

977 self.__init_kwargs[key] = value 

978 if self.__instantiated and key in self.template_fields: 

979 # Resolve upstreams set by assigning an XComArg after initializing 

980 # an operator, example: 

981 # op = BashOperator() 

982 # op.bash_command = "sleep 1" 

983 self._set_xcomargs_dependency(key, value) 
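    # Editor's note: __setattr__ above routes assignments through an optional
    # `_convert_<field>` hook, so e.g. `op.retry_delay = 30` is stored as
    # timedelta(seconds=30) via _convert_retry_delay, and assigning an XComArg to a
    # template field after instantiation still wires up the upstream dependency.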

984 

985 def __init__( 

986 self, 

987 *, 

988 task_id: str, 

989 owner: str = DEFAULT_OWNER, 

990 email: str | Sequence[str] | None = None, 

991 email_on_retry: bool = True, 

992 email_on_failure: bool = True, 

993 retries: int | None = DEFAULT_RETRIES, 

994 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, 

995 retry_exponential_backoff: float = 0, 

996 max_retry_delay: timedelta | float | None = None, 

997 start_date: datetime | None = None, 

998 end_date: datetime | None = None, 

999 depends_on_past: bool = False, 

1000 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

1001 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

1002 wait_for_downstream: bool = False, 

1003 dag: DAG | None = None, 

1004 params: collections.abc.MutableMapping[str, Any] | None = None, 

1005 default_args: dict | None = None, 

1006 priority_weight: int = DEFAULT_PRIORITY_WEIGHT, 

1007 weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE, 

1008 queue: str = DEFAULT_QUEUE, 

1009 pool: str | None = None, 

1010 pool_slots: int = DEFAULT_POOL_SLOTS, 

1011 sla: timedelta | None = None, 

1012 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, 

1013 on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1014 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, 

1015 on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1016 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, 

1017 on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1018 pre_execute: TaskPreExecuteHook | None = None, 

1019 post_execute: TaskPostExecuteHook | None = None, 

1020 trigger_rule: str = DEFAULT_TRIGGER_RULE, 

1021 resources: dict[str, Any] | None = None, 

1022 run_as_user: str | None = None, 

1023 map_index_template: str | None = None, 

1024 max_active_tis_per_dag: int | None = None, 

1025 max_active_tis_per_dagrun: int | None = None, 

1026 executor: str | None = None, 

1027 executor_config: dict | None = None, 

1028 do_xcom_push: bool = True, 

1029 multiple_outputs: bool = False, 

1030 inlets: Any | None = None, 

1031 outlets: Any | None = None, 

1032 task_group: TaskGroup | None = None, 

1033 doc: str | None = None, 

1034 doc_md: str | None = None, 

1035 doc_json: str | None = None, 

1036 doc_yaml: str | None = None, 

1037 doc_rst: str | None = None, 

1038 task_display_name: str | None = None, 

1039 logger_name: str | None = None, 

1040 allow_nested_operators: bool = True, 

1041 **kwargs: Any, 

1042 ): 

1043 # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any 

1044 

1045 # Only apply task_group prefix if this operator was not created from a mapped operator 

1046 # Mapped operators already have the prefix applied during their creation 

1047 if task_group and not self.__from_mapped: 

1048 self.task_id = task_group.child_id(task_id) 

1049 task_group.add(self) 

1050 else: 

1051 self.task_id = task_id 

1052 

1053 super().__init__() 

1054 self.task_group = task_group 

1055 

1056 kwargs.pop("_airflow_mapped_validation_only", None) 

1057 if kwargs: 

1058 raise TypeError( 

1059 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). " 

1060 f"Invalid arguments were:\n**kwargs: {redact(kwargs)}", 

1061 ) 

1062 validate_key(self.task_id) 

1063 

1064 self.owner = owner 

1065 self.email = email 

1066 self.email_on_retry = email_on_retry 

1067 self.email_on_failure = email_on_failure 

1068 

1069 if email is not None: 

1070 warnings.warn( 

1071 "Setting email on a task is deprecated; please migrate to SmtpNotifier.", 

1072 RemovedInAirflow4Warning, 

1073 stacklevel=2, 

1074 ) 

1075 if email and email_on_retry is not None: 

1076 warnings.warn( 

1077 "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.", 

1078 RemovedInAirflow4Warning, 

1079 stacklevel=2, 

1080 ) 

1081 if email and email_on_failure is not None: 

1082 warnings.warn( 

1083 "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.", 

1084 RemovedInAirflow4Warning, 

1085 stacklevel=2, 

1086 ) 

1087 

1088 if execution_timeout is not None and not isinstance(execution_timeout, timedelta): 

1089 raise ValueError( 

1090 f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}" 

1091 ) 

1092 self.execution_timeout = execution_timeout 

1093 

1094 self.on_execute_callback = _collect_from_input(on_execute_callback) 

1095 self.on_failure_callback = _collect_from_input(on_failure_callback) 

1096 self.on_success_callback = _collect_from_input(on_success_callback) 

1097 self.on_retry_callback = _collect_from_input(on_retry_callback) 

1098 self.on_skipped_callback = _collect_from_input(on_skipped_callback) 

1099 self._pre_execute_hook = pre_execute 

1100 self._post_execute_hook = post_execute 

1101 

1102 self.start_date = timezone.convert_to_utc(start_date) 

1103 self.end_date = timezone.convert_to_utc(end_date) 

1104 self.executor = executor 

1105 self.executor_config = executor_config or {} 

1106 self.run_as_user = run_as_user 

1107 # TODO: 

1108 # self.retries = parse_retries(retries) 

1109 self.retries = retries 

1110 self.queue = queue 

1111 self.pool = DEFAULT_POOL_NAME if pool is None else pool 

1112 self.pool_slots = pool_slots 

1113 if self.pool_slots < 1: 

1114 dag_str = f" in dag {dag.dag_id}" if dag else "" 

1115 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1") 

1116 if sla is not None: 

1117 warnings.warn( 

1118 "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1", 

1119 stacklevel=2, 

1120 ) 

1121 

1122 try: 

1123 TriggerRule(trigger_rule) 

1124 except ValueError: 

1125 raise ValueError( 

1126 f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]}," 

1127 f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'." 

1128 ) 

1129 

1130 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule) 

1131 

1132 self.depends_on_past: bool = depends_on_past 

1133 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past 

1134 self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping 

1135 self.wait_for_downstream: bool = wait_for_downstream 

1136 if wait_for_downstream: 

1137 self.depends_on_past = True 

1138 

1139 # Converted by setattr 

1140 self.retry_delay = retry_delay # type: ignore[assignment] 

1141 self.retry_exponential_backoff = retry_exponential_backoff 

1142 if max_retry_delay is not None: 

1143 self.max_retry_delay = max_retry_delay 

1144 

1145 self.resources = resources 

1146 

1147 self.params = ParamsDict(params) 

1148 

1149 self.priority_weight = priority_weight 

1150 self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule) 

1151 

1152 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag 

1153 self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun 

1154 self.do_xcom_push: bool = do_xcom_push 

1155 self.map_index_template: str | None = map_index_template 

1156 self.multiple_outputs: bool = multiple_outputs 

1157 

1158 self.doc_md = doc_md 

1159 self.doc_json = doc_json 

1160 self.doc_yaml = doc_yaml 

1161 self.doc_rst = doc_rst 

1162 self.doc = doc 

1163 

1164 self._task_display_name = task_display_name 

1165 

1166 self.allow_nested_operators = allow_nested_operators 

1167 

1168 self._logger_name = logger_name 

1169 

1170 # Lineage 

1171 self.inlets = _collect_from_input(inlets) 

1172 self.outlets = _collect_from_input(outlets) 

1173 

1174 if isinstance(self.template_fields, str): 

1175 warnings.warn( 

1176 f"The `template_fields` value for {self.task_type} is a string " 

1177 "but should be a list or tuple of string. Wrapping it in a list for execution. " 

1178 f"Please update {self.task_type} accordingly.", 

1179 UserWarning, 

1180 stacklevel=2, 

1181 ) 

1182 self.template_fields = [self.template_fields] 

1183 

1184 self.is_setup = False 

1185 self.is_teardown = False 

1186 

1187 if SetupTeardownContext.active: 

1188 SetupTeardownContext.update_context_map(self) 

1189 

1190 # We set self.dag right at the end as `_convert_dag` calls `dag.add_task` for us, and we need all the 

1191 # other properties to be set at that point 

1192 if dag is not None: 

1193 self.dag = dag 

1194 

1195 validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES) 

1196 

1197 # Ensure priority_weight is within the valid range 

1198 self.priority_weight = db_safe_priority(self.priority_weight) 
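    # Editor's note (hedged usage sketch, operator name hypothetical): typical instantiation
    # inside a Dag context; float delays are coerced and the task is registered on the dag
    # by the `dag` setter at the end of __init__:
    #
    #   with DAG("demo"):
    #       t = BashOperator(task_id="t", bash_command="echo hi",
    #                        retries=2, retry_delay=60)   # retry_delay -> timedelta(seconds=60)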

1199 

1200 def __eq__(self, other): 

1201 if type(self) is type(other): 

1202 # Use getattr() instead of __dict__ as __dict__ doesn't return 

1203 # correct values for properties. 

1204 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) 

1205 return False 

1206 

1207 def __ne__(self, other): 

1208 return not self == other 

1209 

1210 def __hash__(self): 

1211 hash_components = [type(self)] 

1212 for component in self._comps: 

1213 val = getattr(self, component, None) 

1214 try: 

1215 hash(val) 

1216 hash_components.append(val) 

1217 except TypeError: 

1218 hash_components.append(repr(val)) 

1219 return hash(tuple(hash_components)) 

1220 

1221 # /Composing Operators --------------------------------------------- 

1222 

1223 def __gt__(self, other): 

1224 """ 

1225 Return [Operator] > [Outlet]. 

1226 

1227 If other is an attr annotated object it is set as an outlet of this Operator. 

1228 """ 

1229 if not isinstance(other, Iterable): 

1230 other = [other] 

1231 

1232 for obj in other: 

1233 if not attrs.has(obj): 

1234 raise TypeError(f"Left hand side ({obj}) is not an outlet") 

1235 self.add_outlets(other) 

1236 

1237 return self 

1238 

1239 def __lt__(self, other): 

1240 """ 

1241 Return [Inlet] > [Operator] or [Operator] < [Inlet]. 

1242 

1243 If other is an attr annotated object it is set as an inlet to this operator. 

1244 """ 

1245 if not isinstance(other, Iterable): 

1246 other = [other] 

1247 

1248 for obj in other: 

1249 if not attrs.has(obj): 

1250 raise TypeError(f"{obj} cannot be an inlet") 

1251 self.add_inlets(other) 

1252 

1253 return self 

1254 

1255 def __deepcopy__(self, memo: dict[int, Any]): 

1256 # Hack sorting double chained task lists by task_id to avoid hitting 

1257 # max_depth on deepcopy operations. 

1258 sys.setrecursionlimit(5000) # TODO fix this in a better way 

1259 

1260 cls = self.__class__ 

1261 result = cls.__new__(cls) 

1262 memo[id(self)] = result 

1263 

1264 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs 

1265 

1266 for k, v_org in self.__dict__.items(): 

1267 if k not in shallow_copy: 

1268 v = copy.deepcopy(v_org, memo) 

1269 else: 

1270 v = copy.copy(v_org) 

1271 

1272 # Bypass any setters, and set it on the object directly. This works since we are cloning ourself so 

1273 # we know the type is already fine 

1274 result.__dict__[k] = v 

1275 return result 

1276 

1277 def __getstate__(self): 

1278 state = dict(self.__dict__) 

1279 if "_log" in state: 

1280 del state["_log"] 

1281 

1282 return state 

1283 

1284 def __setstate__(self, state): 

1285 self.__dict__ = state 

1286 

1287 def add_inlets(self, inlets: Iterable[Any]): 

1288 """Set inlets to this operator.""" 

1289 self.inlets.extend(inlets) 

1290 

1291 def add_outlets(self, outlets: Iterable[Any]): 

1292 """Define the outlets of this operator.""" 

1293 self.outlets.extend(outlets) 

1294 

1295 def get_dag(self) -> DAG | None: 

1296 return self._dag 

1297 

1298 @property 

1299 def dag(self) -> DAG: 

1300 """Returns the Operator's Dag if set, otherwise raises an error.""" 

1301 if dag := self._dag: 

1302 return dag 

1303 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet") 

1304 

1305 @dag.setter 

1306 def dag(self, dag: DAG | None) -> None: 

1307 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok.""" 

1308 self._dag = dag 

1309 

1310 def _convert__dag(self, dag: DAG | None) -> DAG | None: 

1311 # Called automatically by __setattr__ method 

1312 from airflow.sdk.definitions.dag import DAG 

1313 

1314 if dag is None: 

1315 return dag 

1316 

1317 if not isinstance(dag, DAG): 

1318 raise TypeError(f"Expected dag; received {dag.__class__.__name__}") 

1319 if self._dag is not None and self._dag is not dag: 

1320 raise ValueError(f"The dag assigned to {self} can not be changed.") 

1321 

1322 if self.__from_mapped: 

1323 pass # Don't add to dag -- the mapped task takes the place. 

1324 elif dag.task_dict.get(self.task_id) is not self: 

1325 dag.add_task(self) 

1326 return dag 

1327 
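# Editor's sketch of the assignment rules enforced by _convert__dag (dag ids are
# arbitrary; `DAG` is assumed to be importable from `airflow.sdk`):
#
#     from airflow.sdk import DAG
#
#     dag_a = DAG(dag_id="dag_a")
#     dag_b = DAG(dag_id="dag_b")
#
#     op = BaseOperator(task_id="t")
#     op.dag = dag_a            # first assignment registers the task on dag_a
#     op.dag = dag_a            # re-assigning the same Dag is accepted and is a no-op
#     try:
#         op.dag = dag_b        # a different Dag is rejected
#     except ValueError:
#         pass
#     assert dag_a.task_dict["t"] is op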

1328 @staticmethod 

1329 def _convert_retries(retries: Any) -> int | None: 

1330 if retries is None: 

1331 return 0 

1332 if type(retries) == int: # noqa: E721 

1333 return retries 

1334 try: 

1335 parsed_retries = int(retries) 

1336 except (TypeError, ValueError): 

1337 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}") 

1338 return parsed_retries 

1339 

1340 @staticmethod 

1341 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None: 

1342 if value is None or isinstance(value, timedelta): 

1343 return value 

1344 return timedelta(seconds=value) 

1345 

1346 _convert_retry_delay = _convert_timedelta 

1347 _convert_max_retry_delay = _convert_timedelta 

1348 

1349 @staticmethod 

1350 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None: 

1351 if resources is None: 

1352 return None 

1353 

1354 from airflow.sdk.definitions.operator_resources import Resources 

1355 

1356 if isinstance(resources, Resources): 

1357 return resources 

1358 

1359 return Resources(**resources) 

1360 
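# Editor's sketch of what the static converters above normalize (values are arbitrary;
# `timedelta` is already imported at the top of this module):
#
#     assert BaseOperator._convert_retries(None) == 0
#     assert BaseOperator._convert_retries("3") == 3                       # numeric strings are coerced
#     assert BaseOperator._convert_timedelta(90) == timedelta(seconds=90)  # plain numbers become timedeltas
#     assert BaseOperator._convert_timedelta(None) is None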

1361 def _convert_is_setup(self, value: bool) -> bool: 

1362 """ 

1363 Setter for is_setup property. 

1364 

1365 :meta private: 

1366 """ 

1367 if self.is_teardown and value: 

1368 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.") 

1369 return value 

1370 

1371 def _convert_is_teardown(self, value: bool) -> bool: 

1372 if self.is_setup and value: 

1373 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.") 

1374 return value 

1375 

1376 @property 

1377 def task_display_name(self) -> str: 

1378 return self._task_display_name or self.task_id 

1379 

1380 def has_dag(self): 

1381 """Return True if the Operator has been assigned to a Dag.""" 

1382 return self._dag is not None 

1383 

1384 def _set_xcomargs_dependencies(self) -> None: 

1385 from airflow.sdk.definitions.xcom_arg import XComArg 

1386 

1387 for f in self.template_fields: 

1388 arg = getattr(self, f, NOTSET) 

1389 if arg is not NOTSET: 

1390 XComArg.apply_upstream_relationship(self, arg) 

1391 

1392 def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None: 

1393 """ 

1394 Resolve upstream dependencies of a task. 

1395 

1396 In this way passing an ``XComArg`` as value for a template field 

1397 will result in creating upstream relation between two tasks. 

1398 

1399 **Example**: :: 

1400 

1401 with DAG(...): 

1402 generate_content = GenerateContentOperator(task_id="generate_content") 

1403 send_email = EmailOperator(..., html_content=generate_content.output) 

1404 

1405 # This is equivalent to 

1406 with DAG(...): 

1407 generate_content = GenerateContentOperator(task_id="generate_content") 

1408 send_email = EmailOperator( 

1409 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}" 

1410 ) 

1411 generate_content >> send_email 

1412 

1413 """ 

1414 from airflow.sdk.definitions.xcom_arg import XComArg 

1415 

1416 if field not in self.template_fields: 

1417 return 

1418 XComArg.apply_upstream_relationship(self, newvalue) 

1419 

1420 def on_kill(self) -> None: 

1421 """ 

1422 Override this method to clean up subprocesses when a task instance gets killed. 

1423 

1424 Any use of the threading, subprocess or multiprocessing module within an 

1425 operator needs to be cleaned up, or it will leave ghost processes behind. 

1426 """ 

1427 
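# Editor's sketch of an on_kill override (the operator name and command are hypothetical):
#
#     import subprocess
#
#     class RunScript(BaseOperator):
#         def execute(self, context):
#             self._proc = subprocess.Popen(["sleep", "3600"])
#             self._proc.wait()
#
#         def on_kill(self):
#             # Invoked when the task instance is externally killed; terminate the
#             # child process so no orphan is left behind.
#             if getattr(self, "_proc", None) is not None:
#                 self._proc.terminate()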

1428 def __repr__(self): 

1429 return f"<Task({self.task_type}): {self.task_id}>" 

1430 

1431 @property 

1432 def operator_class(self) -> type[BaseOperator]: # type: ignore[override] 

1433 return self.__class__ 

1434 

1435 @property 

1436 def task_type(self) -> str: 

1437 """@property: type of the task.""" 

1438 return self.__class__.__name__ 

1439 

1440 @property 

1441 def operator_name(self) -> str: 

1442 """@property: use a more friendly display name for the operator, if set.""" 

1443 try: 

1444 return self.custom_operator_name # type: ignore 

1445 except AttributeError: 

1446 return self.task_type 

1447 

1448 @property 

1449 def roots(self) -> list[BaseOperator]: 

1450 """Required by DAGNode.""" 

1451 return [self] 

1452 

1453 @property 

1454 def leaves(self) -> list[BaseOperator]: 

1455 """Required by DAGNode.""" 

1456 return [self] 

1457 

1458 @property 

1459 def output(self) -> XComArg: 

1460 """Return a reference to the XCom pushed by the current operator.""" 

1461 from airflow.sdk.definitions.xcom_arg import XComArg 

1462 

1463 return XComArg(operator=self) 

1464 
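# Editor's sketch of wiring tasks via the `output` XComArg (both operators are
# hypothetical, and `data` is assumed to be declared in LoadOperator.template_fields):
#
#     extract = ExtractOperator(task_id="extract")
#     load = LoadOperator(task_id="load", data=extract.output)
#
# Passing `extract.output` into a template field both pulls the pushed XCom value at
# render time and creates the upstream relationship extract >> load
# (see _set_xcomargs_dependency above).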

1465 @classmethod 

1466 def get_serialized_fields(cls): 

1467 """Stringified Dags and operators contain exactly these fields.""" 

1468 if not cls.__serialized_fields: 

1469 from airflow.sdk.definitions._internal.contextmanager import DagContext 

1470 

1471 # Make sure the following "fake" task is not added to the currently active 

1472 # dag in context; otherwise it raises 

1473 # `RuntimeError: dictionary changed size during iteration` 

1474 # in the SerializedDAG.serialize_dag() call. 

1475 DagContext.push(None) 

1476 cls.__serialized_fields = frozenset( 

1477 vars(BaseOperator(task_id="test")).keys() 

1478 - { 

1479 "upstream_task_ids", 

1480 "default_args", 

1481 "dag", 

1482 "_dag", 

1483 "label", 

1484 "_BaseOperator__instantiated", 

1485 "_BaseOperator__init_kwargs", 

1486 "_BaseOperator__from_mapped", 

1487 "on_failure_fail_dagrun", 

1488 "task_group", 

1489 "_task_type", 

1490 "operator_extra_links", 

1491 "on_execute_callback", 

1492 "on_failure_callback", 

1493 "on_success_callback", 

1494 "on_retry_callback", 

1495 "on_skipped_callback", 

1496 } 

1497 | { # Class level defaults, or `@property` need to be added to this list 

1498 "start_date", 

1499 "end_date", 

1500 "task_type", 

1501 "ui_color", 

1502 "ui_fgcolor", 

1503 "template_ext", 

1504 "template_fields", 

1505 "template_fields_renderers", 

1506 "params", 

1507 "is_setup", 

1508 "is_teardown", 

1509 "on_failure_fail_dagrun", 

1510 "map_index_template", 

1511 "start_trigger_args", 

1512 "_needs_expansion", 

1513 "start_from_trigger", 

1514 "max_retry_delay", 

1515 "has_on_execute_callback", 

1516 "has_on_failure_callback", 

1517 "has_on_success_callback", 

1518 "has_on_retry_callback", 

1519 "has_on_skipped_callback", 

1520 } 

1521 ) 

1522 DagContext.pop() 

1523 

1524 return cls.__serialized_fields 

1525 

1526 def prepare_for_execution(self) -> Self: 

1527 """Lock task for execution to disable custom action in ``__setattr__`` and return a copy.""" 

1528 other = copy.copy(self) 

1529 other._lock_for_execution = True 

1530 return other 

1531 

1532 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: 

1533 """Serialize; required by DAGNode.""" 

1534 from airflow.serialization.enums import DagAttributeTypes 

1535 

1536 return DagAttributeTypes.OP, self.task_id 

1537 

1538 def unmap(self, resolve: None | Mapping[str, Any]) -> Self: 

1539 """ 

1540 Get the "normal" operator from the current operator. 

1541 

1542 Since a BaseOperator is not mapped to begin with, this simply returns 

1543 the original operator. 

1544 

1545 :meta private: 

1546 """ 

1547 return self 

1548 

1549 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None: 

1550 """ 

1551 Get the start_trigger_args value of the current abstract operator. 

1552 

1553 Since a BaseOperator is not mapped to begin with, this simply returns 

1554 the original value of start_trigger_args. 

1555 

1556 :meta private: 

1557 """ 

1558 return self.start_trigger_args 

1559 

1560 def render_template_fields( 

1561 self, 

1562 context: Context, 

1563 jinja_env: jinja2.Environment | None = None, 

1564 ) -> None: 

1565 """ 

1566 Template all attributes listed in *self.template_fields*. 

1567 

1568 This mutates the attributes in-place and is irreversible. 

1569 

1570 :param context: Context dict with values to apply on content. 

1571 :param jinja_env: Jinja's environment to use for rendering. 

1572 """ 

1573 if not jinja_env: 

1574 jinja_env = self.get_template_env() 

1575 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) 

1576 
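# Editor's sketch of how template_fields drives render_template_fields (the operator
# and its field name are hypothetical):
#
#     class GreetOperator(BaseOperator):
#         template_fields = ("message",)
#
#         def __init__(self, *, message: str, **kwargs):
#             super().__init__(**kwargs)
#             self.message = message
#
#         def execute(self, context):
#             # By the time execute() runs, render_template_fields() has replaced the
#             # raw template string on `self.message` with its rendered value, in place.
#             self.log.info(self.message)
#
#     greet = GreetOperator(task_id="greet", message="run_id is {{ run_id }}")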

1577 def pre_execute(self, context: Any): 

1578 """Execute right before self.execute() is called.""" 

1579 

1580 def execute(self, context: Context) -> Any: 

1581 """ 

1582 Derive when creating an operator. 

1583 

1584 The main method to execute the task. Context is the same dictionary used 

1585 as when rendering jinja templates. 

1586 

1587 Refer to get_template_context for more context. 

1588 """ 

1589 raise NotImplementedError() 

1590 

1591 def post_execute(self, context: Any, result: Any = None): 

1592 """ 

1593 Execute right after self.execute() is called. 

1594 

1595 It is passed the execution context and any results returned by the operator. 

1596 """ 

1597 

1598 def defer( 

1599 self, 

1600 *, 

1601 trigger: BaseTrigger, 

1602 method_name: str, 

1603 kwargs: dict[str, Any] | None = None, 

1604 timeout: timedelta | int | float | None = None, 

1605 ) -> NoReturn: 

1606 """ 

1607 Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. 

1608 

1609 This is achieved by raising a special exception (TaskDeferred) 

1610 which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end 

1611 the task instance directly. If the trigger will end the task instance itself, ``method_name`` should 

1612 be None; otherwise, provide the name of the method that should be used when resuming execution in 

1613 the task. 

1614 """ 

1615 from airflow.sdk.exceptions import TaskDeferred 

1616 

1617 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) 

1618 
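# Editor's sketch of the deferral round-trip; the TimeDeltaTrigger import path is an
# assumption and may differ between Airflow versions:
#
#     from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
#
#     class WaitThenFinish(BaseOperator):
#         def execute(self, context):
#             # Suspend this task; once the trigger fires, the task resumes in
#             # execute_complete() via resume_execution() below.
#             self.defer(
#                 trigger=TimeDeltaTrigger(timedelta(seconds=30)),
#                 method_name="execute_complete",
#             )
#
#         def execute_complete(self, context, event=None):
#             return event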

1619 def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): 

1620 """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed.""" 

1621 from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout 

1622 

1623 if next_kwargs is None: 

1624 next_kwargs = {} 

1625 # __fail__ is a special signal value for next_method that indicates 

1626 # this task was scheduled specifically to fail. 

1627 

1628 if next_method == TRIGGER_FAIL_REPR: 

1629 next_kwargs = next_kwargs or {} 

1630 traceback = next_kwargs.get("traceback") 

1631 if traceback is not None: 

1632 self.log.error("Trigger failed:\n%s", "\n".join(traceback)) 

1633 if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT: 

1634 raise TaskDeferralTimeout(error) 

1635 raise TaskDeferralError(error) 

1636 # Grab the callable off the Operator/Task and add in any kwargs 

1637 execute_callable = getattr(self, next_method) 

1638 return execute_callable(context, **next_kwargs) 

1639 

1640 def dry_run(self) -> None: 

1641 """Perform a dry run of the operator: only render the template fields.""" 

1642 self.log.info("Dry run") 

1643 for f in self.template_fields: 

1644 try: 

1645 content = getattr(self, f) 

1646 except AttributeError: 

1647 raise AttributeError( 

1648 f"{f!r} is configured as a template field " 

1649 f"but {self.task_type} does not have this attribute." 

1650 ) 

1651 

1652 if content and isinstance(content, str): 

1653 self.log.info("Rendering template for %s", f) 

1654 self.log.info(content) 

1655 

1656 @property 

1657 def has_on_execute_callback(self) -> bool: 

1658 """Return True if the task has execute callbacks.""" 

1659 return bool(self.on_execute_callback) 

1660 

1661 @property 

1662 def has_on_failure_callback(self) -> bool: 

1663 """Return True if the task has failure callbacks.""" 

1664 return bool(self.on_failure_callback) 

1665 

1666 @property 

1667 def has_on_success_callback(self) -> bool: 

1668 """Return True if the task has success callbacks.""" 

1669 return bool(self.on_success_callback) 

1670 

1671 @property 

1672 def has_on_retry_callback(self) -> bool: 

1673 """Return True if the task has retry callbacks.""" 

1674 return bool(self.on_retry_callback) 

1675 

1676 @property 

1677 def has_on_skipped_callback(self) -> bool: 

1678 """Return True if the task has skipped callbacks.""" 

1679 return bool(self.on_skipped_callback) 

1680 

1681 

1682def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: 

1683 r""" 

1684 Given a number of tasks, builds a dependency chain. 

1685 

1686 This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups, 

1687 or lists containing any mix of these types (or a mix in the same list). If you want to chain between two 

1688 lists you must ensure they have the same length. 

1689 

1690 Using classic operators/sensors: 

1691 

1692 .. code-block:: python 

1693 

1694 chain(t1, [t2, t3], [t4, t5], t6) 

1695 

1696 is equivalent to:: 

1697 

1698 / -> t2 -> t4 \ 

1699 t1 -> t6 

1700 \ -> t3 -> t5 / 

1701 

1702 .. code-block:: python 

1703 

1704 t1.set_downstream(t2) 

1705 t1.set_downstream(t3) 

1706 t2.set_downstream(t4) 

1707 t3.set_downstream(t5) 

1708 t4.set_downstream(t6) 

1709 t5.set_downstream(t6) 

1710 

1711 Using task-decorated functions aka XComArgs: 

1712 

1713 .. code-block:: python 

1714 

1715 chain(x1(), [x2(), x3()], [x4(), x5()], x6()) 

1716 

1717 is equivalent to:: 

1718 

1719 / -> x2 -> x4 \ 

1720 x1 -> x6 

1721 \ -> x3 -> x5 / 

1722 

1723 .. code-block:: python 

1724 

1725 x1 = x1() 

1726 x2 = x2() 

1727 x3 = x3() 

1728 x4 = x4() 

1729 x5 = x5() 

1730 x6 = x6() 

1731 x1.set_downstream(x2) 

1732 x1.set_downstream(x3) 

1733 x2.set_downstream(x4) 

1734 x3.set_downstream(x5) 

1735 x4.set_downstream(x6) 

1736 x5.set_downstream(x6) 

1737 

1738 Using TaskGroups: 

1739 

1740 .. code-block:: python 

1741 

1742 chain(t1, task_group1, task_group2, t2) 

1743 

1744 t1.set_downstream(task_group1) 

1745 task_group1.set_downstream(task_group2) 

1746 task_group2.set_downstream(t2) 

1747 

1748 

1749 It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups: 

1750 

1751 .. code-block:: python 

1752 

1753 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3()) 

1754 

1755 is equivalent to:: 

1756 

1757 / "branch one" -> x1 \ 

1758 t1 -> task_group1 -> x3 

1759 \ "branch two" -> x2 / 

1760 

1761 .. code-block:: python 

1762 

1763 x1 = x1() 

1764 x2 = x2() 

1765 x3 = x3() 

1766 label1 = Label("branch one") 

1767 label2 = Label("branch two") 

1768 t1.set_downstream(label1) 

1769 label1.set_downstream(x1) 

1770 t2.set_downstream(label2) 

1771 label2.set_downstream(x2) 

1772 x1.set_downstream(task_group1) 

1773 x2.set_downstream(task_group1) 

1774 task_group1.set_downstream(x3) 

1775 

1776 # or 

1777 

1778 x1 = x1() 

1779 x2 = x2() 

1780 x3 = x3() 

1781 t1.set_downstream(x1, edge_modifier=Label("branch one")) 

1782 t1.set_downstream(x2, edge_modifier=Label("branch two")) 

1783 x1.set_downstream(task_group1) 

1784 x2.set_downstream(task_group1) 

1785 task_group1.set_downstream(x3) 

1786 

1787 

1788 :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies 

1789 """ 

1790 for up_task, down_task in zip(tasks, tasks[1:]): 

1791 if isinstance(up_task, DependencyMixin): 

1792 up_task.set_downstream(down_task) 

1793 continue 

1794 if isinstance(down_task, DependencyMixin): 

1795 down_task.set_upstream(up_task) 

1796 continue 

1797 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence): 

1798 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}") 

1799 up_task_list = up_task 

1800 down_task_list = down_task 

1801 if len(up_task_list) != len(down_task_list): 

1802 raise ValueError( 

1803 f"Chain not supported for different length Iterable. " 

1804 f"Got {len(up_task_list)} and {len(down_task_list)}." 

1805 ) 

1806 for up_t, down_t in zip(up_task_list, down_task_list): 

1807 up_t.set_downstream(down_t) 

1808 

1809 

1810def cross_downstream( 

1811 from_tasks: Sequence[DependencyMixin], 

1812 to_tasks: DependencyMixin | Sequence[DependencyMixin], 

1813): 

1814 r""" 

1815 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks. 

1816 

1817 Using classic operators/sensors: 

1818 

1819 .. code-block:: python 

1820 

1821 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6]) 

1822 

1823 is equivalent to:: 

1824 

1825 t1 ---> t4 

1826 \ / 

1827 t2 -X -> t5 

1828 / \ 

1829 t3 ---> t6 

1830 

1831 .. code-block:: python 

1832 

1833 t1.set_downstream(t4) 

1834 t1.set_downstream(t5) 

1835 t1.set_downstream(t6) 

1836 t2.set_downstream(t4) 

1837 t2.set_downstream(t5) 

1838 t2.set_downstream(t6) 

1839 t3.set_downstream(t4) 

1840 t3.set_downstream(t5) 

1841 t3.set_downstream(t6) 

1842 

1843 Using task-decorated functions aka XComArgs: 

1844 

1845 .. code-block:: python 

1846 

1847 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()]) 

1848 

1849 is equivalent to:: 

1850 

1851 x1 ---> x4 

1852 \ / 

1853 x2 -X -> x5 

1854 / \ 

1855 x3 ---> x6 

1856 

1857 .. code-block:: python 

1858 

1859 x1 = x1() 

1860 x2 = x2() 

1861 x3 = x3() 

1862 x4 = x4() 

1863 x5 = x5() 

1864 x6 = x6() 

1865 x1.set_downstream(x4) 

1866 x1.set_downstream(x5) 

1867 x1.set_downstream(x6) 

1868 x2.set_downstream(x4) 

1869 x2.set_downstream(x5) 

1870 x2.set_downstream(x6) 

1871 x3.set_downstream(x4) 

1872 x3.set_downstream(x5) 

1873 x3.set_downstream(x6) 

1874 

1875 It is also possible to mix between classic operator/sensor and XComArg tasks: 

1876 

1877 .. code-block:: python 

1878 

1879 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()]) 

1880 

1881 is equivalent to:: 

1882 

1883 t1 ---> x1 

1884 \ / 

1885 x2 -X -> t2 

1886 / \ 

1887 t3 ---> x3 

1888 

1889 .. code-block:: python 

1890 

1891 x1 = x1() 

1892 x2 = x2() 

1893 x3 = x3() 

1894 t1.set_downstream(x1) 

1895 t1.set_downstream(t2) 

1896 t1.set_downstream(x3) 

1897 x2.set_downstream(x1) 

1898 x2.set_downstream(t2) 

1899 x2.set_downstream(x3) 

1900 t3.set_downstream(x1) 

1901 t3.set_downstream(t2) 

1902 t3.set_downstream(x3) 

1903 

1904 :param from_tasks: List of tasks or XComArgs to start from. 

1905 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies. 

1906 """ 

1907 for task in from_tasks: 

1908 task.set_downstream(to_tasks) 

1909 

1910 

1911def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]): 

1912 """ 

1913 Simplify task dependency definition. 

1914 

1915 E.g.: suppose you want precedence like so:: 

1916 

1917 ╭─op2─╮ ╭─op4─╮ 

1918 op1─┤ ├─├─op5─┤─op7 

1919 ╰─op3─╯ ╰─op6─╯ 

1920 

1921 Then you can accomplish like so:: 

1922 

1923 chain_linear(op1, [op2, op3], [op4, op5, op6], op7) 

1924 

1925 :param elements: a list of operators / lists of operators 

1926 """ 

1927 if not elements: 

1928 raise ValueError("No tasks provided; nothing to do.") 

1929 prev_elem = None 

1930 deps_set = False 

1931 for curr_elem in elements: 

1932 if isinstance(curr_elem, EdgeModifier): 

1933 raise ValueError("Labels are not supported by chain_linear") 

1934 if prev_elem is not None: 

1935 for task in prev_elem: 

1936 task >> curr_elem 

1937 if not deps_set: 

1938 deps_set = True 

1939 prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem 

1940 if not deps_set: 

1941 raise ValueError("No dependencies were set. Did you forget to expand with `*`?")