Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/bases/operator.py: 36%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

679 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import abc 

21import collections.abc 

22import contextlib 

23import copy 

24import inspect 

25import sys 

26import warnings 

27from collections.abc import Callable, Collection, Iterable, Mapping, Sequence 

28from contextvars import ContextVar 

29from dataclasses import dataclass, field 

30from datetime import datetime, timedelta 

31from enum import Enum 

32from functools import total_ordering, wraps 

33from types import FunctionType 

34from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast 

35 

36import attrs 

37 

38from airflow.sdk import TriggerRule, timezone 

39from airflow.sdk._shared.secrets_masker import redact 

40from airflow.sdk.definitions._internal.abstractoperator import ( 

41 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

42 DEFAULT_OWNER, 

43 DEFAULT_POOL_NAME, 

44 DEFAULT_POOL_SLOTS, 

45 DEFAULT_PRIORITY_WEIGHT, 

46 DEFAULT_QUEUE, 

47 DEFAULT_RETRIES, 

48 DEFAULT_RETRY_DELAY, 

49 DEFAULT_TASK_EXECUTION_TIMEOUT, 

50 DEFAULT_TRIGGER_RULE, 

51 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

52 DEFAULT_WEIGHT_RULE, 

53 AbstractOperator, 

54 DependencyMixin, 

55 TaskStateChangeCallback, 

56) 

57from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack 

58from airflow.sdk.definitions._internal.node import validate_key 

59from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext 

60from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args 

61from airflow.sdk.definitions.edges import EdgeModifier 

62from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs 

63from airflow.sdk.definitions.param import ParamsDict 

64from airflow.sdk.exceptions import RemovedInAirflow4Warning 

65from airflow.task.priority_strategy import ( 

66 PriorityWeightStrategy, 

67 airflow_priority_weight_strategies, 

68 validate_and_load_priority_weight_strategy, 

69) 

70 

# Databases do not support arbitrary precision integers, so priority weights are clamped to a range that
# every supported backend can store. Per-backend integer ranges:
#   postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
#   mysql:    -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
#   sqlite:   -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
# The narrowest of these, a signed 32-bit integer, is used as the safe range.
DB_SAFE_MINIMUM = -(2**31)
DB_SAFE_MAXIMUM = 2**31 - 1


def db_safe_priority(priority_weight: int) -> int:
    """
    Clamp a priority weight into the signed 32-bit range accepted by every supported database.

    :param priority_weight: the raw, possibly out-of-range, priority weight.
    :return: ``priority_weight`` limited to ``[DB_SAFE_MINIMUM, DB_SAFE_MAXIMUM]``.
    """
    capped_above = min(DB_SAFE_MAXIMUM, priority_weight)
    return max(DB_SAFE_MINIMUM, capped_above)

82 

83 

# Generic callable type; used by ``_collect_from_input`` to type callback values.
C = TypeVar("C", bound=Callable)
# Generic function type; ``BaseOperatorMeta._apply_defaults`` takes and returns a ``T`` so the
# wrapped ``__init__`` keeps the same static type as the original.
T = TypeVar("T", bound=FunctionType)

86 

# Names needed only for static type annotations; this block is not executed at runtime.
if TYPE_CHECKING:
    from types import ClassMethodDescriptorType

    import jinja2
    from typing_extensions import Self

    from airflow.sdk.bases.operatorlink import BaseOperatorLink
    from airflow.sdk.definitions.context import Context
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.definitions.operator_resources import Resources
    from airflow.sdk.definitions.taskgroup import TaskGroup
    from airflow.sdk.definitions.xcom_arg import XComArg
    from airflow.serialization.enums import DagAttributeTypes
    from airflow.task.priority_strategy import PriorityWeightStrategy
    from airflow.triggers.base import BaseTrigger, StartTriggerArgs

    # Type of a ``pre_execute`` hook: called with the task context.
    TaskPreExecuteHook = Callable[[Context], None]
    # Type of a ``post_execute`` hook: called with the task context and the task result.
    TaskPostExecuteHook = Callable[[Context, Any], None]

105 

# Public names exported by ``from airflow.sdk.bases.operator import *``.
__all__ = [
    "BaseOperator",
    "chain",
    "chain_linear",
    "cross_downstream",
]

112 

113 

class TriggerFailureReason(str, Enum):
    """
    Reasons for trigger failures.

    Internal use only.

    :meta private:
    """

    # Members subclass ``str``, so each compares equal to its human-readable value string.
    TRIGGER_TIMEOUT = "Trigger timeout"
    TRIGGER_FAILURE = "Trigger failure"

125 

126 

# Sentinel string value used to represent a trigger failure.
TRIGGER_FAIL_REPR = "__fail__"
"""String value to represent trigger failure.

Internal use only.

:meta private:
"""

134 

135 

def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
    """
    Collect the default args and params a task inherits from its Dag and task group.

    :param dag: the Dag the task belongs to, if any. Without a Dag nothing is inherited.
    :param task_group: the task group the task belongs to, if any. Its ``default_args`` are
        applied on top of the Dag's, so they win on key collisions. Task groups do not
        contribute params here.
    :return: a ``(default_args, params)`` tuple. The args dict is a shallow copy of the Dag's
        ``default_args`` and the params are a deep copy of the Dag's ``params`` (with a missing
        param source filled in as ``"dag"``), so updating the returned objects does not modify
        the Dag's own containers.
    :raises TypeError: if the task group's ``default_args`` is truthy but not a mapping.
    """
    if not dag:
        return {}, ParamsDict()
    # Copy so that merging task-group (and later task-level) values never mutates the Dag.
    dag_args = copy.copy(dag.default_args)
    dag_params = copy.deepcopy(dag.params)
    dag_params._fill_missing_param_source("dag")
    if task_group:
        group_args = task_group.default_args
        # A falsy value (``None``, ``{}``) means "nothing to inherit". Guarding the update as well
        # as the type check keeps ``None`` from reaching ``dict.update`` and failing there.
        if group_args:
            if not isinstance(group_args, collections.abc.Mapping):
                raise TypeError(f"default_args must be a mapping, got {type(group_args)}")
            dag_args.update(group_args)
    return dag_args, dag_params

147 

148 

def get_merged_defaults(
    dag: DAG | None,
    task_group: TaskGroup | None,
    task_params: collections.abc.MutableMapping | None,
    task_default_args: dict | None,
) -> tuple[dict, ParamsDict]:
    """
    Merge Dag-, task-group- and task-level default args and params for one task.

    Default args are layered Dag → task group → ``task_default_args``. Params are layered Dag params →
    ``task_params`` → ``task_default_args["params"]`` (when that key is present and non-empty), each
    later layer overriding earlier ones on key collisions.

    :param dag: the Dag the task belongs to, if any.
    :param task_group: the task group the task belongs to, if any.
    :param task_params: the task's own ``params`` argument, if any.
    :param task_default_args: the task's own ``default_args`` argument, if any.
    :return: a ``(default_args, params)`` tuple.
    :raises TypeError: if ``task_params`` or ``task_default_args`` is truthy but not a mapping
        (or if the task group's ``default_args`` is invalid, see ``_get_parent_defaults``).
    """
    args, params = _get_parent_defaults(dag, task_group)
    if task_params:
        if not isinstance(task_params, collections.abc.Mapping):
            raise TypeError(f"params must be a mapping, got {type(task_params)}")

        task_params = ParamsDict(task_params)
        task_params._fill_missing_param_source("task")
        params.update(task_params)

    if task_default_args:
        if not isinstance(task_default_args, collections.abc.Mapping):
            # Report the type of the argument that actually failed the check.
            raise TypeError(f"default_args must be a mapping, got {type(task_default_args)}")
        args.update(task_default_args)
        # Only an absent "params" key is expected and tolerated. Keep the KeyError handling around the
        # lookup alone so that a KeyError raised while building or merging params is not swallowed.
        try:
            raw_default_params = task_default_args["params"]
        except KeyError:
            pass
        else:
            if params_from_default_args := ParamsDict(raw_default_params or {}):
                params_from_default_args._fill_missing_param_source("task")
                params.update(params_from_default_args)

    return args, params

174 

175 

def parse_retries(retries: Any) -> int:
    """
    Normalise an operator's ``retries`` argument to an ``int``.

    :param retries: ``None`` (treated as 0), an ``int``, or any value accepted by ``int()``,
        such as a numeric string.
    :return: the number of retries as a plain ``int``; ``None`` is never returned.
    :raises RuntimeError: if ``retries`` cannot be converted with ``int()``. The original
        ``TypeError``/``ValueError`` is chained as ``__cause__``.
    """
    if retries is None:
        return 0
    # Exact type check on purpose: ``bool`` and other ``int`` subclasses fall through to ``int()``
    # so that the result is always a plain ``int``.
    if type(retries) is int:
        return retries
    try:
        return int(retries)
    except (TypeError, ValueError) as err:
        raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}") from err

186 

187 

188def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta: 

189 if isinstance(value, timedelta): 

190 return value 

191 return timedelta(seconds=value) 

192 

193 

def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
    """
    Build a ``Resources`` object from a dict of keyword arguments.

    :param resources: keyword arguments for ``Resources``, or ``None``.
    :return: ``None`` when ``resources`` is ``None``; otherwise ``Resources(**resources)``.
    """
    if resources is not None:
        # Imported lazily, only when there is something to build.
        from airflow.sdk.definitions.operator_resources import Resources

        return Resources(**resources)
    return None

200 

201 

202class _PartialDescriptor: 

203 """A descriptor that guards against ``.partial`` being called on Task objects.""" 

204 

205 class_method: ClassMethodDescriptorType | None = None 

206 

207 def __get__( 

208 self, obj: BaseOperator, cls: type[BaseOperator] | None = None 

209 ) -> Callable[..., OperatorPartial]: 

210 # Call this "partial" so it looks nicer in stack traces. 

211 def partial(**kwargs): 

212 raise TypeError("partial can only be called on Operator classes, not Tasks themselves") 

213 

214 if obj is not None: 

215 return partial 

216 return self.class_method.__get__(cls, cls) 

217 

218 

# Fallback values for operator arguments. ``partial()`` copies every key here that is still absent from
# its kwargs, after explicit arguments and Dag/task-group default_args have been applied.
# NOTE(review): the list/dict values are single module-level objects that ``partial()`` inserts as-is
# (the ``on_*_callback`` lists are then rebuilt by ``_collect_from_input``). Confirm nothing downstream
# mutates ``inlets``/``outlets``/``executor_config`` in place.
OPERATOR_DEFAULTS: dict[str, Any] = {
    "allow_nested_operators": True,
    "depends_on_past": False,
    "email_on_failure": True,
    "email_on_retry": True,
    "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
    # "executor": DEFAULT_EXECUTOR,
    "executor_config": {},
    "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
    "inlets": [],
    "map_index_template": None,
    "on_execute_callback": [],
    "on_failure_callback": [],
    "on_retry_callback": [],
    "on_skipped_callback": [],
    "on_success_callback": [],
    "outlets": [],
    "owner": DEFAULT_OWNER,
    "pool_slots": DEFAULT_POOL_SLOTS,
    "priority_weight": DEFAULT_PRIORITY_WEIGHT,
    "queue": DEFAULT_QUEUE,
    "retries": DEFAULT_RETRIES,
    "retry_delay": DEFAULT_RETRY_DELAY,
    # 0 disables exponential backoff (constant delay between retries).
    "retry_exponential_backoff": 0,
    "trigger_rule": DEFAULT_TRIGGER_RULE,
    "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
    "wait_for_downstream": False,
    "weight_rule": DEFAULT_WEIGHT_RULE,
}

248 

249 

# This is what handles the actual mapping: ``partial()`` is bound onto operator classes as a classmethod
# (``BaseOperatorMeta.__new__`` stores ``classmethod(partial)`` on ``_PartialDescriptor``). It validates
# the keyword arguments for a mapped operator and bundles them into an ``OperatorPartial``.
#
# Under TYPE_CHECKING, the accepted keyword arguments are spelled out for type checkers. At runtime,
# everything beyond task_id/dag/task_group/params arrives through ``**kwargs``.

if TYPE_CHECKING:

    def partial(
        operator_class: type[BaseOperator],
        *,
        task_id: str,
        dag: DAG | None = None,
        task_group: TaskGroup | None = None,
        start_date: datetime = ...,
        end_date: datetime = ...,
        owner: str = ...,
        email: None | str | Iterable[str] = ...,
        params: collections.abc.MutableMapping | None = None,
        resources: dict[str, Any] | None = ...,
        trigger_rule: str = ...,
        depends_on_past: bool = ...,
        ignore_first_depends_on_past: bool = ...,
        wait_for_past_depends_before_skipping: bool = ...,
        wait_for_downstream: bool = ...,
        retries: int | None = ...,
        queue: str = ...,
        pool: str = ...,
        pool_slots: int = ...,
        execution_timeout: timedelta | None = ...,
        max_retry_delay: None | timedelta | float = ...,
        retry_delay: timedelta | float = ...,
        retry_exponential_backoff: float = ...,
        priority_weight: int = ...,
        weight_rule: str | PriorityWeightStrategy = ...,
        sla: timedelta | None = ...,
        map_index_template: str | None = ...,
        max_active_tis_per_dag: int | None = ...,
        max_active_tis_per_dagrun: int | None = ...,
        on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        run_as_user: str | None = ...,
        executor: str | None = ...,
        executor_config: dict | None = ...,
        inlets: Any | None = ...,
        outlets: Any | None = ...,
        doc: str | None = ...,
        doc_md: str | None = ...,
        doc_json: str | None = ...,
        doc_yaml: str | None = ...,
        doc_rst: str | None = ...,
        task_display_name: str | None = ...,
        logger_name: str | None = ...,
        allow_nested_operators: bool = True,
        **kwargs,
    ) -> OperatorPartial: ...
else:

    def partial(
        operator_class: type[BaseOperator],
        *,
        task_id: str,
        dag: DAG | None = None,
        task_group: TaskGroup | None = None,
        params: collections.abc.MutableMapping | None = None,
        **kwargs,
    ):
        """
        Create an ``OperatorPartial`` for ``operator_class`` from the given keyword arguments.

        Each value is taken from the first source that provides it: the explicit kwargs, then the
        Dag/task-group ``default_args`` (only for names accepted by ``BaseOperator.__init__``), then
        ``OPERATOR_DEFAULTS``.

        :raises TypeError: if ``task_concurrency`` is passed, or if ``params``/``default_args`` is not
            a mapping (via ``get_merged_defaults``).
        :raises ValueError: if the resolved ``pool_slots`` is less than 1.
        """
        from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext

        validate_mapping_kwargs(operator_class, "partial", kwargs)

        # Fall back to the Dag / task group of the surrounding context manager, if any.
        dag = dag or DagContext.get_current()
        if dag:
            task_group = task_group or TaskGroupContext.get_current(dag)
        if task_group:
            task_id = task_group.child_id(task_id)

        # Merge Dag and task group level defaults into user-supplied values.
        dag_default_args, partial_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=params,
            task_default_args=kwargs.pop("default_args", None),
        )

        # Create partial_kwargs from args and kwargs
        partial_kwargs: dict[str, Any] = {
            "task_id": task_id,
            "dag": dag,
            "task_group": task_group,
            **kwargs,
        }

        # Inject Dag-level default args into args provided to this function.
        # Most of the default args will be retrieved during unmapping; here we
        # only ensure base properties are correctly set for the scheduler.
        # ``_BaseOperatorMeta__param_names`` is the (name-mangled) set of non-variadic ``__init__``
        # parameter names attached by ``BaseOperatorMeta._apply_defaults``.
        partial_kwargs.update(
            (k, v)
            for k, v in dag_default_args.items()
            if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names
        )

        # Fill fields not provided by the user with default values.
        partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs)

        # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
        if "task_concurrency" in kwargs:  # Reject deprecated option.
            raise TypeError("unexpected argument: task_concurrency")
        if start_date := partial_kwargs.get("start_date", None):
            partial_kwargs["start_date"] = timezone.convert_to_utc(start_date)
        if end_date := partial_kwargs.get("end_date", None):
            partial_kwargs["end_date"] = timezone.convert_to_utc(end_date)
        if partial_kwargs["pool_slots"] < 1:
            dag_str = ""
            if dag:
                dag_str = f" in dag {dag.dag_id}"
            raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
        # Falsy values (e.g. 0 or an explicit None) are left as-is and not converted.
        if retries := partial_kwargs.get("retries"):
            partial_kwargs["retries"] = BaseOperator._convert_retries(retries)
        partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"])
        partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay(
            partial_kwargs.get("max_retry_delay", None)
        )

        # Normalise every on_*_callback to a list. The right-hand side (which binds ``attr`` via the
        # walrus) is evaluated before the subscription target ``partial_kwargs[attr]``.
        for k in ("execute", "failure", "success", "retry", "skipped"):
            partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr := f"on_{k}_callback"))

        return OperatorPartial(
            operator_class=operator_class,
            kwargs=partial_kwargs,
            params=partial_params,
        )

381 

382 

class ExecutorSafeguard:
    """
    The ExecutorSafeguard decorator.

    Checks if the execute method of an operator isn't manually called outside
    the TaskInstance as we want to avoid bad mixing between decorated and
    classic operators.

    A call to a wrapped ``execute`` is accepted without complaint in either of two cases:

    * the caller passes ``<ClassName>__sentinel=ExecutorSafeguard.sentinel_value``;
    * the operator is already recorded in :attr:`tracker`, e.g. a subclass ``execute`` that
      calls ``super().execute()``.

    Any other call raises ``RuntimeError`` when the operator's ``allow_nested_operators`` is
    false; otherwise it logs a warning and proceeds. No check is made when :attr:`test_mode` is true.
    """

    # When true, the "called outside of the Task Runner" check is skipped entirely.
    test_mode: ClassVar[bool] = False
    # The operator whose wrapped ``execute`` is currently running in this context.
    tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel")
    # Unique marker object that authorises a call when passed as the sentinel kwarg.
    sentinel_value: ClassVar[object] = object()

    @classmethod
    def decorator(cls, func):
        """Wrap an operator ``execute`` method ``func`` with the safeguard check described above."""

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # The sentinel kwarg is keyed by the concrete class name and is always removed, so
            # ``func`` never receives it.
            sentinel_key = f"{self.__class__.__name__}__sentinel"
            sentinel = kwargs.pop(sentinel_key, None)

            # Each ``tracker.set`` below registers a matching ``tracker.reset`` on the stack, so the
            # previous tracker value is restored when this call exits, including when ``func`` raises.
            with contextlib.ExitStack() as stack:
                if sentinel is cls.sentinel_value:
                    token = cls.tracker.set(self)
                    sentinel = self
                    stack.callback(cls.tracker.reset, token)
                else:
                    # No sentinel passed in, maybe the subclass execute did have it passed?
                    sentinel = cls.tracker.get(None)

                if not cls.test_mode and sentinel is not self:
                    message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!"
                    if not self.allow_nested_operators:
                        raise RuntimeError(message)
                    self.log.warning(message)

                    # Now that we've logged, set sentinel so that `super()` calls don't log again
                    token = cls.tracker.set(self)
                    stack.callback(cls.tracker.reset, token)

                return func(self, *args, **kwargs)

        return wrapper

425 

426 

# Enable ExecutorSafeguard's test mode from the "core.unit_test_mode" setting, but only when
# ``airflow.configuration`` has already been imported; otherwise ``test_mode`` keeps its class default.
# NOTE(review): the guard checks ``airflow.configuration`` while the import below is
# ``airflow.sdk.configuration`` -- confirm this mismatch is intentional.
if "airflow.configuration" in sys.modules:
    # Don't try and import it if it's not already loaded
    from airflow.sdk.configuration import conf

    ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode")

432 

433 

434def _collect_from_input(value_or_values: None | C | Collection[C]) -> list[C]: 

435 if not value_or_values: 

436 return [] 

437 if isinstance(value_or_values, Collection): 

438 return list(value_or_values) 

439 return [value_or_values] 

440 

441 

class BaseOperatorMeta(abc.ABCMeta):
    """
    Metaclass of BaseOperator.

    For every class it creates, it:

    * wraps a concrete (non-abstract) ``execute`` defined in the class body with
      ``ExecutorSafeguard.decorator``;
    * points a ``_PartialDescriptor`` found in the class namespace at ``classmethod(partial)``;
    * wraps ``__init__`` with :meth:`_apply_defaults` when the class defines its own ``__init__``.
    """

    @classmethod
    def _apply_defaults(cls, func: T) -> T:
        """
        Wrap an operator ``__init__`` so that unspecified arguments are filled from ``default_args``.

        The returned wrapper:

        * rejects positional arguments;
        * resolves ``dag`` / ``task_group`` from the active context when they are not passed;
        * merges Dag, task-group and task-level ``default_args`` and ``params`` (see
          ``get_merged_defaults``) and fills every ``__init__`` parameter not passed explicitly;
        * raises ``TypeError`` naming every required argument (other than ``task_id``) still missing;
        * records the final kwargs in ``_BaseOperator__init_kwargs`` and, once per instance (only for
          the outermost ``__init__`` and not when created from a mapped operator), sets up
          XComArg-based upstream dependencies.

        :param func: the ``__init__`` function to wrap.
        :return: the wrapper, typed as ``func``'s type.
        """
        # Cache inspect.signature for the wrapper closure to avoid calling it
        # at every decorated invocation. This is separate sig_cache created
        # per decoration, i.e. each function decorated using apply_defaults will
        # have a different sig_cache.
        sig_cache = inspect.signature(func)
        non_variadic_params = {
            name: param
            for (name, param) in sig_cache.parameters.items()
            if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        }
        # Parameters without a default must be supplied (task_id is exempt from this check).
        non_optional_args = {
            name
            for name, param in non_variadic_params.items()
            if param.default == param.empty and name != "task_id"
        }

        fixup_decorator_warning_stack(func)

        @wraps(func)
        def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
            from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext

            if args:
                raise TypeError("Use keyword arguments when initializing operators")

            # An explicit "_airflow_from_mapped" kwarg wins; otherwise reuse the flag an outer
            # (subclass) __init__ already stored on the instance.
            instantiated_from_mapped = kwargs.pop(
                "_airflow_from_mapped",
                getattr(self, "_BaseOperator__from_mapped", False),
            )

            dag: DAG | None = kwargs.get("dag")
            if dag is None:
                dag = DagContext.get_current()
                if dag is not None:
                    kwargs["dag"] = dag

            task_group: TaskGroup | None = kwargs.get("task_group")
            if dag and not task_group:
                task_group = TaskGroupContext.get_current(dag)
                if task_group is not None:
                    kwargs["task_group"] = task_group

            default_args, merged_params = get_merged_defaults(
                dag=dag,
                task_group=task_group,
                task_params=kwargs.pop("params", None),
                task_default_args=kwargs.pop("default_args", None),
            )

            # Explicit kwargs always win over default_args.
            for arg in sig_cache.parameters:
                if arg not in kwargs and arg in default_args:
                    kwargs[arg] = default_args[arg]

            missing_args = non_optional_args.difference(kwargs)
            if len(missing_args) == 1:
                raise TypeError(f"missing keyword argument {missing_args.pop()!r}")
            if missing_args:
                display = ", ".join(repr(a) for a in sorted(missing_args))
                raise TypeError(f"missing keyword arguments {display}")

            if merged_params:
                kwargs["params"] = merged_params

            # Optional per-class hook that may rewrite kwargs/default_args before __init__ runs.
            hook = getattr(self, "_hook_apply_defaults", None)
            if hook:
                args, kwargs = hook(**kwargs, default_args=default_args)
                default_args = kwargs.pop("default_args", {})

            # object.__setattr__ bypasses any custom __setattr__ defined on the operator class.
            if not hasattr(self, "_BaseOperator__init_kwargs"):
                object.__setattr__(self, "_BaseOperator__init_kwargs", {})
            object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped)

            result = func(self, **kwargs, default_args=default_args)

            # Store the args passed to init -- we need them to support task.map serialization!
            self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore

            # Set upstream task defined by XComArgs passed to template fields of the operator.
            # BUT: only do this _ONCE_, not once for each class in the hierarchy
            if not instantiated_from_mapped and func == self.__init__.__wrapped__:  # type: ignore[misc]
                self._set_xcomargs_dependencies()
                # Mark instance as instantiated so that future attr setting updates xcomarg-based deps.
                object.__setattr__(self, "_BaseOperator__instantiated", True)

            return result

        # Inside this class body these attribute names are mangled to
        # ``_BaseOperatorMeta__non_optional_args`` / ``_BaseOperatorMeta__param_names``; ``partial()``
        # reads the latter under its mangled name.
        apply_defaults.__non_optional_args = non_optional_args  # type: ignore
        apply_defaults.__param_names = set(non_variadic_params)  # type: ignore

        return cast("T", apply_defaults)

    def __new__(cls, name, bases, namespace, **kwargs):
        # Wrap only an execute defined directly in this class body, and only if it is concrete.
        execute_method = namespace.get("execute")
        if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
            namespace["execute"] = ExecutorSafeguard.decorator(execute_method)
        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
        with contextlib.suppress(KeyError):
            # Update the partial descriptor with the class method, so it calls the actual function
            # (but let subclasses override it if they need to)
            partial_desc = vars(new_cls)["partial"]
            if isinstance(partial_desc, _PartialDescriptor):
                partial_desc.class_method = classmethod(partial)

        # We patch `__init__` only if the class defines it.
        first_superclass = new_cls.mro()[1]
        if new_cls.__init__ is not first_superclass.__init__:
            new_cls.__init__ = cls._apply_defaults(new_cls.__init__)

        return new_cls

564 

565 

# Expected types of selected BaseOperator arguments, used to validate that the arguments passed to the
# BaseOperator are of the correct type. Keys are argument names; each value is a type or a tuple of types.
#
# TODO: This mapping is a temporary solution until we find a more sophisticated method for argument
# validation. One potential method is to use `get_type_hints` from the typing module. However, this is
# not fully compatible with future annotations for Python versions below 3.10. Once we require a minimum
# Python version that supports `get_type_hints` effectively, or find a better approach, we can replace
# this manual type-checking method.
BASEOPERATOR_ARGS_EXPECTED_TYPES = {
    "task_id": str,
    "email": (str, Sequence),
    "email_on_retry": bool,
    "email_on_failure": bool,
    "retries": int,
    "retry_exponential_backoff": (int, float),
    "depends_on_past": bool,
    "ignore_first_depends_on_past": bool,
    "wait_for_past_depends_before_skipping": bool,
    "wait_for_downstream": bool,
    "priority_weight": int,
    "queue": str,
    "pool": str,
    "pool_slots": int,
    "trigger_rule": str,
    "run_as_user": str,
    "task_concurrency": int,
    "map_index_template": str,
    "max_active_tis_per_dag": int,
    "max_active_tis_per_dagrun": int,
    "executor": str,
    "do_xcom_push": bool,
    "multiple_outputs": bool,
    "doc": str,
    "doc_md": str,
    "doc_json": str,
    "doc_yaml": str,
    "doc_rst": str,
    "task_display_name": str,
    "logger_name": str,
    "allow_nested_operators": bool,
    "start_date": datetime,
    "end_date": datetime,
}

607 

608 

609# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in 

610# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it. 

611# 

# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you
# get nowhere, leave a record here for the next poor soul describing the problems you ran into.
#
# @ashb, 2024/10/14
# - "Can't combine custom __setattr__ with on_setattr hooks"
# - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses

618@total_ordering 

619@dataclass(repr=False) 

620class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): 

621 r""" 

622 Abstract base class for all operators. 

623 

624 Since operators create objects that become nodes in the Dag, BaseOperator 

625 contains many recursive methods for Dag crawling behavior. To derive from 

626 this class, you are expected to override the constructor and the 'execute' 

627 method. 

628 

629 Operators derived from this class should perform or trigger certain tasks 

630 synchronously (wait for completion). Example of operators could be an 

631 operator that runs a Pig job (PigOperator), a sensor operator that 

632 waits for a partition to land in Hive (HiveSensorOperator), or one that 

633 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these 

634 operators (tasks) target specific operations, running specific scripts, 

635 functions or data transfers. 

636 

637 This class is abstract and shouldn't be instantiated. Instantiating a 

638 class derived from this one results in the creation of a task object, 

639 which ultimately becomes a node in Dag objects. Task dependencies should 

640 be set by using the set_upstream and/or set_downstream methods. 

641 

642 :param task_id: a unique, meaningful id for the task 

643 :param owner: the owner of the task. Using a meaningful description 

644 (e.g. user/person/team/role name) to clarify ownership is recommended. 

645 :param email: the 'to' email address(es) used in email alerts. This can be a 

646 single email or multiple ones. Multiple addresses can be specified as a 

647 comma or semicolon separated string or by passing a list of strings. (deprecated) 

648 :param email_on_retry: Indicates whether email alerts should be sent when a 

649 task is retried (deprecated) 

650 :param email_on_failure: Indicates whether email alerts should be sent when 

651 a task failed (deprecated) 

652 :param retries: the number of retries that should be performed before 

653 failing the task 

654 :param retry_delay: delay between retries, can be set as ``timedelta`` or 

655 ``float`` seconds, which will be converted into ``timedelta``, 

656 the default is ``timedelta(seconds=300)``. 

657 :param retry_exponential_backoff: multiplier for exponential backoff between retries. 

658 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff 

659 (delay doubles with each retry). For example, with retry_delay=4min and 

660 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc. 

661 :param max_retry_delay: maximum delay interval between retries, can be set as 

662 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``. 

663 :param start_date: The ``start_date`` for the task, determines 

664 the ``logical_date`` for the first task instance. The best practice 

665 is to have the start_date rounded 

666 to your Dag's ``schedule_interval``. Daily jobs have their start_date 

667 some day at 00:00:00, hourly jobs have their start_date at 00:00 

668 of a specific hour. Note that Airflow simply looks at the latest 

669 ``logical_date`` and adds the ``schedule_interval`` to determine 

670 the next ``logical_date``. It is also very important 

671 to note that different tasks' dependencies 

672 need to line up in time. If task A depends on task B and their 

673 start_date are offset in a way that their logical_date don't line 

674 up, A's dependencies will never be met. If you are looking to delay 

675 a task, for example running a daily task at 2AM, look into the 

676 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using 

677 dynamic ``start_date`` and recommend using fixed ones. Read the 

678 FAQ entry about start_date for more information. 

679 :param end_date: if specified, the scheduler won't go beyond this date 

680 :param depends_on_past: when set to true, task instances will run 

681 sequentially and only if the previous instance has succeeded or has been skipped. 

682 The task instance for the start_date is allowed to run. 

683 :param wait_for_past_depends_before_skipping: when set to true, if the task instance 

684 should be marked as skipped, and depends_on_past is true, the ti will stay on None state 

685 waiting the task of the previous run 

686 :param wait_for_downstream: when set to true, an instance of task 

687 X will wait for tasks immediately downstream of the previous instance 

688 of task X to finish successfully or be skipped before it runs. This is useful if the 

689 different instances of a task X alter the same asset, and this asset 

690 is used by tasks downstream of task X. Note that depends_on_past 

691 is forced to True wherever wait_for_downstream is used. Also note that 

692 only tasks *immediately* downstream of the previous task instance are waited 

693 for; the statuses of any tasks further downstream are ignored. 

694 :param dag: a reference to the dag the task is attached to (if any) 

695 :param priority_weight: priority weight of this task against other task. 

696 This allows the executor to trigger higher priority tasks before 

697 others when things get backed up. Set priority_weight as a higher 

698 number for more important tasks. 

699 As not all database engines support 64-bit integers, values are capped with 32-bit. 

700 Valid range is from -2,147,483,648 to 2,147,483,647. 

701 :param weight_rule: weighting method used for the effective total 

702 priority weight of the task. Options are: 

703 ``{ downstream | upstream | absolute }`` default is ``downstream`` 

704 When set to ``downstream`` the effective weight of the task is the 

705 aggregate sum of all downstream descendants. As a result, upstream 

706 tasks will have higher weight and will be scheduled more aggressively 

707 when using positive weight values. This is useful when you have 

708 multiple dag run instances and desire to have all upstream tasks to 

709 complete for all runs before each dag can continue processing 

710 downstream tasks. When set to ``upstream`` the effective weight is the 

711 aggregate sum of all upstream ancestors. This is the opposite where 

712 downstream tasks have higher weight and will be scheduled more 

713 aggressively when using positive weight values. This is useful when you 

714 have multiple dag run instances and prefer to have each dag complete 

715 before starting upstream tasks of other dags. When set to 

716 ``absolute``, the effective weight is the exact ``priority_weight`` 

717 specified without additional weighting. You may want to do this when 

718 you know exactly what priority weight each task should have. 

719 Additionally, when set to ``absolute``, there is bonus effect of 

720 significantly speeding up the task creation process as for very large 

721 Dags. Options can be set as string or using the constants defined in 

722 the static class ``airflow.utils.WeightRule``. 

723 Irrespective of the weight rule, resulting priority values are capped with 32-bit. 

724 |experimental| 

725 Since 2.9.0, Airflow allows to define custom priority weight strategy, 

726 by creating a subclass of 

727 ``airflow.task.priority_strategy.PriorityWeightStrategy`` and registering 

728 in a plugin, then providing the class path or the class instance via 

729 ``weight_rule`` parameter. The custom priority weight strategy will be 

730 used to calculate the effective total priority weight of the task instance. 

731 :param queue: which queue to target when running this job. Not 

732 all executors implement queue management, the CeleryExecutor 

733 does support targeting specific queues. 

734 :param pool: the slot pool this task should run in, slot pools are a 

735 way to limit concurrency for certain tasks 

736 :param pool_slots: the number of pool slots this task should use (>= 1) 

737 Values less than 1 are not allowed. 

738 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a 

739 new implementation in Airflow >=3.1. 

740 :param execution_timeout: max time allowed for the execution of 

741 this task instance, if it goes beyond it will raise and fail. 

742 :param on_failure_callback: a function or list of functions to be called when a task instance 

743 of this task fails. A context dictionary is passed as a single

744 parameter to this function. Context contains references to related 

745 objects to the task instance and is documented under the macros 

746 section of the API. 

747 :param on_execute_callback: much like the ``on_failure_callback`` except 

748 that it is executed right before the task is executed. 

749 :param on_retry_callback: much like the ``on_failure_callback`` except 

750 that it is executed when retries occur. 

751 :param on_success_callback: much like the ``on_failure_callback`` except 

752 that it is executed when the task succeeds. 

753 :param on_skipped_callback: much like the ``on_failure_callback`` except 

754 that it is executed when a skip occurs; this callback will be called only if AirflowSkipException is raised.

755 Explicitly, it is NOT called if a task is never started because of a preceding branching

756 decision in the Dag or a trigger rule which causes execution to skip so that the task execution 

757 is never scheduled. 

758 :param pre_execute: a function to be called immediately before task 

759 execution, receiving a context dictionary; raising an exception will 

760 prevent the task from being executed. 

761 :param post_execute: a function to be called immediately after task 

762 execution, receiving a context dictionary and task result; raising an 

763 exception will prevent the task from succeeding. 

764 :param trigger_rule: defines the rule by which dependencies are applied 

765 for the task to get triggered. Options are: 

766 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done | 

767 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}`` 

768 default is ``all_success``. Options can be set as string or 

769 using the constants defined in the static class 

770 ``airflow.utils.TriggerRule`` 

771 :param resources: A map of resource parameter names (the argument names of the 

772 Resources constructor) to their values. 

773 :param run_as_user: unix username to impersonate while running the task 

774 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent 

775 runs across logical_dates. 

776 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent 

777 task instances per Dag run. 

778 :param executor: Which executor to target when running this task. NOT YET SUPPORTED 

779 :param executor_config: Additional task-level configuration parameters that are 

780 interpreted by a specific executor. Parameters are namespaced by the name of 

781 executor. 

782 

783 **Example**: to run this task in a specific docker container through 

784 the KubernetesExecutor :: 

785 

786 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}) 

787 

788 :param do_xcom_push: if True, an XCom is pushed containing the Operator's 

789 result 

790 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each 

791 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom. 

792 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not 

793 using a TaskGroup as a context manager. 

794 :param doc: Add documentation or notes to your Task objects that is visible in 

795 Task Instance details View in the Webserver 

796 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects 

797 that is visible in Task Instance details View in the Webserver 

798 :param doc_rst: Add documentation (in RST format) or notes to your Task objects 

799 that is visible in Task Instance details View in the Webserver 

800 :param doc_json: Add documentation (in JSON format) or notes to your Task objects 

801 that is visible in Task Instance details View in the Webserver 

802 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects 

803 that is visible in Task Instance details View in the Webserver 

804 :param task_display_name: The display name of the task which appears on the UI. 

805 :param logger_name: Name of the logger used by the Operator to emit logs. 

806 If set to `None` (default), the logger name will fall back to 

807 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have 

808 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger). 

809 :param allow_nested_operators: if True, when an operator is executed within another one a warning message 

810 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested 

811 within another one). In future releases of Airflow this parameter will be removed and an exception 

812 will always be thrown when operators are nested within each other (default is True). 

813 

814 **Example**: example of a bad operator mixin usage:: 

815 

816 @task(provide_context=True) 

817 def say_hello_world(**context): 

818 hello_world_task = BashOperator( 

819 task_id="hello_world_task", 

820 bash_command="python -c \"print('Hello, world!')\"", 

821 dag=dag, 

822 ) 

823 hello_world_task.execute(context) 

824 """ 

825 

# ---- Identity, ownership and (deprecated) e-mail settings ----
task_id: str
owner: str = DEFAULT_OWNER
email: str | Sequence[str] | None = None
email_on_retry: bool = True
email_on_failure: bool = True

# ---- Retry policy and scheduling window ----
retries: int | None = DEFAULT_RETRIES
retry_delay: timedelta = DEFAULT_RETRY_DELAY
retry_exponential_backoff: float = 0
max_retry_delay: timedelta | float | None = None
start_date: datetime | None = None
end_date: datetime | None = None

# ---- Dependencies on previous runs ----
depends_on_past: bool = False
ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
wait_for_downstream: bool = False

# ---- Params, priority and placement ----
# A ParamsDict while the Dag is being defined; at execution time this becomes a plain dict.
params: ParamsDict | dict = field(default_factory=ParamsDict)
default_args: dict | None = None
priority_weight: int = DEFAULT_PRIORITY_WEIGHT
weight_rule: PriorityWeightStrategy = field(
    default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE]
)
queue: str = DEFAULT_QUEUE
pool: str = DEFAULT_POOL_NAME
pool_slots: int = DEFAULT_POOL_SLOTS
execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT

# ---- State-change callbacks and execute hooks ----
on_execute_callback: Sequence[TaskStateChangeCallback] = ()
on_failure_callback: Sequence[TaskStateChangeCallback] = ()
on_success_callback: Sequence[TaskStateChangeCallback] = ()
on_retry_callback: Sequence[TaskStateChangeCallback] = ()
on_skipped_callback: Sequence[TaskStateChangeCallback] = ()
_pre_execute_hook: TaskPreExecuteHook | None = None
_post_execute_hook: TaskPostExecuteHook | None = None

# ---- Triggering, concurrency and execution environment ----
trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE
resources: dict[str, Any] | None = None
run_as_user: str | None = None
task_concurrency: int | None = None
map_index_template: str | None = None
max_active_tis_per_dag: int | None = None
max_active_tis_per_dagrun: int | None = None
executor: str | None = None
executor_config: dict | None = None

# ---- XCom and lineage ----
do_xcom_push: bool = True
multiple_outputs: bool = False
inlets: list[Any] = field(default_factory=list)
outlets: list[Any] = field(default_factory=list)

# ---- Grouping, documentation and logging ----
task_group: TaskGroup | None = None
doc: str | None = None
doc_md: str | None = None
doc_json: str | None = None
doc_yaml: str | None = None
doc_rst: str | None = None
_task_display_name: str | None = None
logger_name: str | None = None
allow_nested_operators: bool = True

# ---- Setup / teardown role ----
is_setup: bool = False
is_teardown: bool = False

# ---- Templating ----
# TODO: Task-SDK: Make these ClassVar[]?
template_fields: Collection[str] = ()
template_ext: Sequence[str] = ()

template_fields_renderers: ClassVar[dict[str, str]] = {}

operator_extra_links: Collection[BaseOperatorLink] = ()

# ---- UI colours (background / foreground) ----
ui_color: str = "#fff"
ui_fgcolor: str = "#000"

partial: Callable[..., OperatorPartial] = _PartialDescriptor()  # type: ignore

_dag: DAG | None = field(init=False, default=None)

# Typed as optional so it matches the declaration in LoggingMixin.
_log_config_logger_name: str | None = field(default="airflow.task.operators", init=False)
_logger_name: str | None = None

# Filled in lazily the first time get_serialized_fields() is called.
__serialized_fields: ClassVar[frozenset[str] | None] = None

# Attribute names compared by __eq__ and folded into __hash__.
_comps: ClassVar[set[str]] = {
    "task_id",
    "dag_id",
    "owner",
    "email",
    "email_on_retry",
    "retry_delay",
    "retry_exponential_backoff",
    "max_retry_delay",
    "start_date",
    "end_date",
    "depends_on_past",
    "wait_for_downstream",
    "priority_weight",
    "execution_timeout",
    "has_on_execute_callback",
    "has_on_failure_callback",
    "has_on_success_callback",
    "has_on_retry_callback",
    "has_on_skipped_callback",
    "do_xcom_push",
    "multiple_outputs",
    "allow_nested_operators",
    "executor",
}

# When True, rendered template fields are written to the DB again after execution; Taskflow
# decorators that change template fields while running (e.g. @task.bash) rely on this.
overwrite_rtif_after_execution: bool = False

# True once the constructor has run.
__instantiated: bool = False
# Constructor kwargs (after apply_defaults()), kept so a task can be re-created when mapping.
# Populated by the metaclass.
__init_kwargs: dict[str, Any] = field(init=False)

# Flipped to True right before execute() is called.
_lock_for_execution: bool = False

# True for operators instantiated by a mapped operator.
__from_mapped: bool = False

start_trigger_args: StartTriggerArgs | None = None
start_from_trigger: bool = False

# Attributes that every operator copies shallowly (not deep-copied) in __deepcopy__.
_base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = (
    "user_defined_macros",
    "user_defined_filters",
    "params",
)

# Per-operator extension of the shallow-copy list; subclasses override this.
shallow_copy_attrs: Sequence[str] = ()

965 

def __setattr__(self: BaseOperator, key: str, value: Any):
    """
    Store an attribute, applying the optional converter and post-assignment bookkeeping.

    When the instance exposes a ``_convert_<key>`` callable, *value* is replaced by its
    result before being stored. Unless the operator came from a mapped operator or is
    locked for execution, the stored value is then:

    * mirrored into the recorded constructor kwargs, if *key* is one of them, and
    * resolved for XComArg upstream dependencies, if the constructor has finished and
      *key* is a template field (e.g. ``op.bash_command = other.output``).
    """
    convert = getattr(self, f"_convert_{key}", None)
    if convert:
        value = convert(value)
    super().__setattr__(key, value)

    # Validation copies of mapped operators and operators being executed skip the extras.
    if self.__from_mapped or self._lock_for_execution:
        return

    recorded_kwargs = self.__init_kwargs
    if key in recorded_kwargs:
        recorded_kwargs[key] = value

    # Assigning after __init__ (op = BashOperator(); op.bash_command = ...) must still
    # wire up upstream tasks referenced through an XComArg.
    if self.__instantiated and key in self.template_fields:
        self._set_xcomargs_dependency(key, value)

980 

def __init__(
    self,
    *,
    task_id: str,
    owner: str = DEFAULT_OWNER,
    email: str | Sequence[str] | None = None,
    email_on_retry: bool = True,
    email_on_failure: bool = True,
    retries: int | None = DEFAULT_RETRIES,
    retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
    retry_exponential_backoff: float = 0,
    max_retry_delay: timedelta | float | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    depends_on_past: bool = False,
    ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
    wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
    wait_for_downstream: bool = False,
    dag: DAG | None = None,
    params: collections.abc.MutableMapping[str, Any] | None = None,
    default_args: dict | None = None,
    priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
    weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
    queue: str = DEFAULT_QUEUE,
    pool: str | None = None,
    pool_slots: int = DEFAULT_POOL_SLOTS,
    sla: timedelta | None = None,
    execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
    on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
    on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
    on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
    on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
    on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
    pre_execute: TaskPreExecuteHook | None = None,
    post_execute: TaskPostExecuteHook | None = None,
    trigger_rule: str = DEFAULT_TRIGGER_RULE,
    resources: dict[str, Any] | None = None,
    run_as_user: str | None = None,
    map_index_template: str | None = None,
    max_active_tis_per_dag: int | None = None,
    max_active_tis_per_dagrun: int | None = None,
    executor: str | None = None,
    executor_config: dict | None = None,
    do_xcom_push: bool = True,
    multiple_outputs: bool = False,
    inlets: Any | None = None,
    outlets: Any | None = None,
    task_group: TaskGroup | None = None,
    doc: str | None = None,
    doc_md: str | None = None,
    doc_json: str | None = None,
    doc_yaml: str | None = None,
    doc_rst: str | None = None,
    task_display_name: str | None = None,
    logger_name: str | None = None,
    allow_nested_operators: bool = True,
    **kwargs: Any,
):
    """
    Initialize the operator; see the class docstring for the meaning of each parameter.

    :raises TypeError: if unknown keyword arguments are passed.
    :raises ValueError: if ``execution_timeout`` is not a ``timedelta``, ``pool_slots`` is
        less than 1, or ``trigger_rule`` is not a valid ``TriggerRule`` value.
    """
    # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any

    # Only apply task_group prefix if this operator was not created from a mapped operator.
    # Mapped operators already have the prefix applied during their creation.
    if task_group and not self.__from_mapped:
        self.task_id = task_group.child_id(task_id)
        task_group.add(self)
    else:
        self.task_id = task_id

    super().__init__()
    self.task_group = task_group

    # Drop this internal marker so it is not reported as an invalid argument below.
    kwargs.pop("_airflow_mapped_validation_only", None)
    if kwargs:
        raise TypeError(
            f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
            f"Invalid arguments were:\n**kwargs: {redact(kwargs)}",
        )
    validate_key(self.task_id)

    self.owner = owner
    self.email = email
    self.email_on_retry = email_on_retry
    self.email_on_failure = email_on_failure

    if email is not None:
        warnings.warn(
            "Setting email on a task is deprecated; please migrate to SmtpNotifier.",
            RemovedInAirflow4Warning,
            stacklevel=2,
        )
    # NOTE(review): email_on_retry / email_on_failure default to True (never None), so the two
    # warnings below fire whenever ``email`` is truthy, even if the flags were not passed.
    if email and email_on_retry is not None:
        warnings.warn(
            "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.",
            RemovedInAirflow4Warning,
            stacklevel=2,
        )
    if email and email_on_failure is not None:
        warnings.warn(
            "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.",
            RemovedInAirflow4Warning,
            stacklevel=2,
        )

    if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
        raise ValueError(
            f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
        )
    self.execution_timeout = execution_timeout

    # Each callback argument may be None, a single callable, or a collection of callables.
    self.on_execute_callback = _collect_from_input(on_execute_callback)
    self.on_failure_callback = _collect_from_input(on_failure_callback)
    self.on_success_callback = _collect_from_input(on_success_callback)
    self.on_retry_callback = _collect_from_input(on_retry_callback)
    self.on_skipped_callback = _collect_from_input(on_skipped_callback)
    self._pre_execute_hook = pre_execute
    self._post_execute_hook = post_execute

    self.start_date = timezone.convert_to_utc(start_date)
    self.end_date = timezone.convert_to_utc(end_date)
    self.executor = executor
    self.executor_config = executor_config or {}
    self.run_as_user = run_as_user
    # Normalized by ``_convert_retries`` through ``__setattr__``.
    self.retries = retries
    self.queue = queue
    self.pool = DEFAULT_POOL_NAME if pool is None else pool
    self.pool_slots = pool_slots
    if self.pool_slots < 1:
        dag_str = f" in dag {dag.dag_id}" if dag else ""
        raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
    if sla is not None:
        warnings.warn(
            "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1",
            stacklevel=2,
        )

    # Parse once; the enum's own "is not a valid TriggerRule" error adds nothing to ours.
    try:
        parsed_trigger_rule = TriggerRule(trigger_rule)
    except ValueError:
        raise ValueError(
            f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]}, "
            f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
        ) from None

    self.trigger_rule: TriggerRule = parsed_trigger_rule

    self.depends_on_past: bool = depends_on_past
    self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
    self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
    self.wait_for_downstream: bool = wait_for_downstream
    if wait_for_downstream:
        # wait_for_downstream forces depends_on_past on.
        self.depends_on_past = True

    # Numbers of seconds are converted to timedelta by the _convert_* hooks in __setattr__.
    self.retry_delay = retry_delay  # type: ignore[assignment]
    self.retry_exponential_backoff = retry_exponential_backoff
    if max_retry_delay is not None:
        self.max_retry_delay = max_retry_delay

    self.resources = resources

    self.params = ParamsDict(params)

    self.priority_weight = priority_weight
    self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)

    self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
    self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
    self.do_xcom_push: bool = do_xcom_push
    self.map_index_template: str | None = map_index_template
    self.multiple_outputs: bool = multiple_outputs

    self.doc_md = doc_md
    self.doc_json = doc_json
    self.doc_yaml = doc_yaml
    self.doc_rst = doc_rst
    self.doc = doc

    self._task_display_name = task_display_name

    self.allow_nested_operators = allow_nested_operators

    self._logger_name = logger_name

    # Lineage
    self.inlets = _collect_from_input(inlets)
    self.outlets = _collect_from_input(outlets)

    if isinstance(self.template_fields, str):
        warnings.warn(
            f"The `template_fields` value for {self.task_type} is a string "
            "but should be a list or tuple of string. Wrapping it in a list for execution. "
            f"Please update {self.task_type} accordingly.",
            UserWarning,
            stacklevel=2,
        )
        self.template_fields = [self.template_fields]

    self.is_setup = False
    self.is_teardown = False

    if SetupTeardownContext.active:
        SetupTeardownContext.update_context_map(self)

    # Assign the Dag last: the ``dag`` setter stores ``_dag``, which triggers ``_convert__dag`` via
    # ``__setattr__``, and that calls ``dag.add_task`` -- every other attribute must already be set.
    if dag is not None:
        self.dag = dag

    validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)

    # Ensure priority_weight is within the valid range
    self.priority_weight = db_safe_priority(self.priority_weight)

1195 

1196 def __eq__(self, other): 

1197 if type(self) is type(other): 

1198 # Use getattr() instead of __dict__ as __dict__ doesn't return 

1199 # correct values for properties. 

1200 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) 

1201 return False 

1202 

1203 def __ne__(self, other): 

1204 return not self == other 

1205 

1206 def __hash__(self): 

1207 hash_components = [type(self)] 

1208 for component in self._comps: 

1209 val = getattr(self, component, None) 

1210 try: 

1211 hash(val) 

1212 hash_components.append(val) 

1213 except TypeError: 

1214 hash_components.append(repr(val)) 

1215 return hash(tuple(hash_components)) 

1216 

1217 # /Composing Operators --------------------------------------------- 

1218 

def __gt__(self, other):
    """
    Return [Operator] > [Outlet].

    Register *other* (a single attrs-decorated object or an iterable of them) as outlet(s)
    of this operator, and return the operator so the expression can be chained.

    :raises TypeError: if any item is not an attrs-decorated object; nothing is added then.
    """
    # Materialize exactly once. The validation loop and ``add_outlets`` both iterate over the
    # items, so a one-shot iterator (e.g. a generator) would otherwise be exhausted by the
    # validation and no outlet would be registered.
    items = list(other) if isinstance(other, Iterable) else [other]

    for obj in items:
        if not attrs.has(obj):
            # Message kept verbatim for compatibility with existing callers/tests.
            raise TypeError(f"Left hand side ({obj}) is not an outlet")
    self.add_outlets(items)

    return self

1234 

def __lt__(self, other):
    """
    Return [Inlet] > [Operator] or [Operator] < [Inlet].

    Register *other* (a single attrs-decorated object or an iterable of them) as inlet(s)
    of this operator, and return the operator so the expression can be chained.

    :raises TypeError: if any item is not an attrs-decorated object; nothing is added then.
    """
    # Materialize exactly once. The validation loop and ``add_inlets`` both iterate over the
    # items, so a one-shot iterator (e.g. a generator) would otherwise be exhausted by the
    # validation and no inlet would be registered.
    items = list(other) if isinstance(other, Iterable) else [other]

    for obj in items:
        if not attrs.has(obj):
            raise TypeError(f"{obj} cannot be an inlet")
    self.add_inlets(items)

    return self

1250 

1251 def __deepcopy__(self, memo: dict[int, Any]): 

1252 # Hack sorting double chained task lists by task_id to avoid hitting 

1253 # max_depth on deepcopy operations. 

1254 sys.setrecursionlimit(5000) # TODO fix this in a better way 

1255 

1256 cls = self.__class__ 

1257 result = cls.__new__(cls) 

1258 memo[id(self)] = result 

1259 

1260 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs 

1261 

1262 for k, v_org in self.__dict__.items(): 

1263 if k not in shallow_copy: 

1264 v = copy.deepcopy(v_org, memo) 

1265 else: 

1266 v = copy.copy(v_org) 

1267 

1268 # Bypass any setters, and set it on the object directly. This works since we are cloning ourself so 

1269 # we know the type is already fine 

1270 result.__dict__[k] = v 

1271 return result 

1272 

1273 def __getstate__(self): 

1274 state = dict(self.__dict__) 

1275 if "_log" in state: 

1276 del state["_log"] 

1277 

1278 return state 

1279 

1280 def __setstate__(self, state): 

1281 self.__dict__ = state 

1282 

def add_inlets(self, inlets: Iterable[Any]):
    """Append every item of *inlets*, in order, to this operator's existing inlet list."""
    current = self.inlets
    current.extend(inlets)

1286 

def add_outlets(self, outlets: Iterable[Any]):
    """Append every item of *outlets*, in order, to this operator's existing outlet list."""
    current = self.outlets
    current.extend(outlets)

1290 

def get_dag(self) -> DAG | None:
    """Return the Dag stored on this operator, or ``None`` if none has been assigned (never raises)."""
    assigned = self._dag
    return assigned

1293 

1294 @property 

1295 def dag(self) -> DAG: 

1296 """Returns the Operator's Dag if set, otherwise raises an error.""" 

1297 if dag := self._dag: 

1298 return dag 

1299 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet") 

1300 

1301 @dag.setter 

1302 def dag(self, dag: DAG | None) -> None: 

1303 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok.""" 

1304 self._dag = dag 

1305 

def _convert__dag(self, dag: DAG | None) -> DAG | None:
    """
    Validate a value assigned to ``_dag`` and register this task with that Dag.

    Invoked by ``__setattr__`` whenever ``_dag`` is written. ``None`` passes through
    untouched. Otherwise the value must be a ``DAG`` and may not replace a different
    Dag that is already assigned. Unless this operator was created from a mapped
    operator (whose mapped task takes its place), it is added to the Dag when the Dag
    does not already map its task_id to this very object.

    :raises TypeError: if *dag* is not a ``DAG``.
    :raises ValueError: if a different Dag is already assigned.
    """
    # Imported here rather than at module level (presumably to avoid an import cycle).
    from airflow.sdk.definitions.dag import DAG

    if dag is None:
        return None

    if not isinstance(dag, DAG):
        raise TypeError(f"Expected dag; received {dag.__class__.__name__}")
    if self._dag is not None and self._dag is not dag:
        raise ValueError(f"The dag assigned to {self} can not be changed.")

    if not self.__from_mapped and dag.task_dict.get(self.task_id) is not self:
        dag.add_task(self)
    return dag

1323 

1324 @staticmethod 

1325 def _convert_retries(retries: Any) -> int | None: 

1326 if retries is None: 

1327 return 0 

1328 if type(retries) == int: # noqa: E721 

1329 return retries 

1330 try: 

1331 parsed_retries = int(retries) 

1332 except (TypeError, ValueError): 

1333 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}") 

1334 return parsed_retries 

1335 

1336 @staticmethod 

1337 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None: 

1338 if value is None or isinstance(value, timedelta): 

1339 return value 

1340 return timedelta(seconds=value) 

1341 

1342 _convert_retry_delay = _convert_timedelta 

1343 _convert_max_retry_delay = _convert_timedelta 

1344 

1345 @staticmethod 

1346 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None: 

1347 if resources is None: 

1348 return None 

1349 

1350 from airflow.sdk.definitions.operator_resources import Resources 

1351 

1352 if isinstance(resources, Resources): 

1353 return resources 

1354 

1355 return Resources(**resources) 

1356 

1357 def _convert_is_setup(self, value: bool) -> bool: 

1358 """ 

1359 Setter for is_setup property. 

1360 

1361 :meta private: 

1362 """ 

1363 if self.is_teardown and value: 

1364 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.") 

1365 return value 

1366 

1367 def _convert_is_teardown(self, value: bool) -> bool: 

1368 if self.is_setup and value: 

1369 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.") 

1370 return value 

1371 

1372 @property 

1373 def task_display_name(self) -> str: 

1374 return self._task_display_name or self.task_id 

1375 

def has_dag(self):
    """Tell whether a Dag has been assigned to this Operator."""
    assigned = self._dag
    return assigned is not None

1379 

def _set_xcomargs_dependencies(self) -> None:
    """Apply XComArg upstream relationships for every template field that has a value on this operator."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    for field_name in self.template_fields:
        value = getattr(self, field_name, NOTSET)
        if value is NOTSET:
            continue
        XComArg.apply_upstream_relationship(self, value)

1387 

def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None:
    """
    Resolve upstream dependencies of a task for one template field.

    When *field* is one of this operator's template fields, *newvalue* is handed to
    ``XComArg.apply_upstream_relationship``, so that using an ``XComArg`` as a
    template-field value makes the producing task an upstream of this one.
    Non-template fields are ignored.

    **Example**: ::

        with DAG(...):
            generate_content = GenerateContentOperator(task_id="generate_content")
            send_email = EmailOperator(..., html_content=generate_content.output)

        # This is equivalent to
        with DAG(...):
            generate_content = GenerateContentOperator(task_id="generate_content")
            send_email = EmailOperator(
                ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
            )
            generate_content >> send_email

    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if field in self.template_fields:
        XComArg.apply_upstream_relationship(self, newvalue)

1415 

def on_kill(self) -> None:
    """
    Hook called when a task instance gets killed; the base implementation does nothing.

    Operators that start threads, subprocesses or use multiprocessing should override
    this to clean them up; otherwise those processes are left behind as ghosts.
    """

1423 

1424 def __repr__(self): 

1425 return f"<Task({self.task_type}): {self.task_id}>" 

1426 

1427 @property 

1428 def operator_class(self) -> type[BaseOperator]: # type: ignore[override] 

1429 return self.__class__ 

1430 

1431 @property 

1432 def task_type(self) -> str: 

1433 """@property: type of the task.""" 

1434 return self.__class__.__name__ 

1435 

1436 @property 

1437 def operator_name(self) -> str: 

1438 """@property: use a more friendly display name for the operator, if set.""" 

1439 try: 

1440 return self.custom_operator_name # type: ignore 

1441 except AttributeError: 

1442 return self.task_type 

1443 

1444 @property 

1445 def roots(self) -> list[BaseOperator]: 

1446 """Required by DAGNode.""" 

1447 return [self] 

1448 

1449 @property 

1450 def leaves(self) -> list[BaseOperator]: 

1451 """Required by DAGNode.""" 

1452 return [self] 

1453 

@property
def output(self) -> XComArg:
    """A fresh ``XComArg`` referring to the XCom this operator pushes."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    reference = XComArg(operator=self)
    return reference

1460 

@classmethod
def get_serialized_fields(cls):
    """
    Stringified Dags and operators contain exactly these fields.

    The set is computed once and cached on the class. It starts from the instance
    attributes of a throw-away ``BaseOperator(task_id="test")``, removes attributes that
    are never serialized, and adds class-level defaults and properties that ``vars()``
    does not report.
    """
    if not cls.__serialized_fields:
        from airflow.sdk.definitions._internal.contextmanager import DagContext

        # Make sure the following "fake" task is not added to the currently active Dag in
        # context; otherwise SerializedDAG.serialize_dag() fails with
        # `RuntimeError: dictionary changed size during iteration`.
        DagContext.push(None)
        try:
            cls.__serialized_fields = frozenset(
                vars(BaseOperator(task_id="test")).keys()
                - {
                    "upstream_task_ids",
                    "default_args",
                    "dag",
                    "_dag",
                    "label",
                    "_BaseOperator__instantiated",
                    "_BaseOperator__init_kwargs",
                    "_BaseOperator__from_mapped",
                    "on_failure_fail_dagrun",
                    "task_group",
                    "_task_type",
                    "operator_extra_links",
                    "on_execute_callback",
                    "on_failure_callback",
                    "on_success_callback",
                    "on_retry_callback",
                    "on_skipped_callback",
                }
                | {  # Class level defaults, or `@property` need to be added to this list
                    "start_date",
                    "end_date",
                    "task_type",
                    "ui_color",
                    "ui_fgcolor",
                    "template_ext",
                    "template_fields",
                    "template_fields_renderers",
                    "params",
                    "is_setup",
                    "is_teardown",
                    "on_failure_fail_dagrun",
                    "map_index_template",
                    "start_trigger_args",
                    "_needs_expansion",
                    "start_from_trigger",
                    "max_retry_delay",
                    "has_on_execute_callback",
                    "has_on_failure_callback",
                    "has_on_success_callback",
                    "has_on_retry_callback",
                    "has_on_skipped_callback",
                }
            )
        finally:
            # Always restore the Dag context, even if building the sample operator raised;
            # otherwise the pushed ``None`` would stay on the context stack.
            DagContext.pop()

    return cls.__serialized_fields

1521 

def prepare_for_execution(self) -> Self:
    """
    Return a shallow copy of this operator that is locked for execution.

    The copy has ``_lock_for_execution`` set, which makes ``__setattr__`` skip its custom
    bookkeeping. The original operator is left unlocked.
    """
    locked = copy.copy(self)
    locked._lock_for_execution = True
    return locked

1527 

def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
    """Serialize; required by DAGNode. An operator serializes to the ``OP`` tag plus its task_id."""
    from airflow.serialization.enums import DagAttributeTypes

    tag = DagAttributeTypes.OP
    return tag, self.task_id

1533 

def unmap(self, resolve: None | Mapping[str, Any]) -> Self:
    """
    Return the "normal" (unmapped) form of this operator.

    A plain ``BaseOperator`` is never mapped, so there is nothing to undo:
    ``resolve`` is ignored and the very same operator object is returned.

    :meta private:
    """
    unmapped = self
    return unmapped

1544 

def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
    """
    Return this operator's ``start_trigger_args``.

    A plain ``BaseOperator`` is not mapped, so no expansion happens: ``context``
    is unused and the stored attribute value is returned unchanged.

    :meta private:
    """
    trigger_args = self.start_trigger_args
    return trigger_args

1555 

def render_template_fields(
    self,
    context: Context,
    jinja_env: jinja2.Environment | None = None,
) -> None:
    """
    Render every attribute named in ``self.template_fields`` with Jinja.

    The attributes are overwritten in place, and this cannot be undone.

    :param context: Context dict with values to apply on content.
    :param jinja_env: Jinja environment to render with. When falsy (e.g. omitted),
        the environment returned by ``self.get_template_env()`` is used instead.
    """
    env = jinja_env or self.get_template_env()
    # ``self`` is passed explicitly as the object whose fields get rendered; the fresh
    # empty set is presumably a per-call "already seen" tracker -- TODO confirm.
    self._do_render_template_fields(self, self.template_fields, context, env, set())

1572 

def pre_execute(self, context: Any):
    """
    Hook invoked immediately before ``self.execute()``.

    The base implementation is a no-op.
    """
    return None

1575 

def execute(self, context: Context) -> Any:
    """
    Run the operator's main work; concrete operators must override this.

    ``context`` is the same dictionary that is used when rendering Jinja
    templates; see ``get_template_context`` for what it contains.

    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()

1586 

def post_execute(self, context: Any, result: Any = None):
    """
    Hook invoked immediately after ``self.execute()``.

    Receives the execution context and whatever the operator returned.
    The base implementation is a no-op.
    """
    return None

1593 

def defer(
    self,
    *,
    trigger: BaseTrigger,
    method_name: str,
    kwargs: dict[str, Any] | None = None,
    timeout: timedelta | int | float | None = None,
) -> NoReturn:
    """
    Mark this operator "deferred", suspending it until ``trigger`` fires an event.

    Deferral is signalled by raising ``TaskDeferred``, which is caught in the main
    ``_execute_task`` wrapper. A trigger can either send execution back to the
    task or end the task instance directly.

    :param trigger: the trigger to wait on.
    :param method_name: name of the method to resume execution in. It should be
        None when the trigger ends the task instance itself.
        NOTE(review): the annotation only permits ``str``; confirm None is accepted.
    :param kwargs: extra keyword arguments passed along to ``TaskDeferred``.
    :param timeout: optional timeout passed along to ``TaskDeferred``.
    :raises TaskDeferred: always.
    """
    from airflow.sdk.exceptions import TaskDeferred

    deferral = TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
    raise deferral

1614 

def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
    """
    Entrypoint called by the Task Runner (instead of ``execute``) when this task is resumed.

    :param next_method: name of the method on this operator to call, or the special
        ``TRIGGER_FAIL_REPR`` signal meaning the task was scheduled specifically to fail.
    :param next_kwargs: keyword arguments forwarded to ``next_method``; ``None`` means none.
    :param context: execution context, passed as the first positional argument to ``next_method``.
    :return: whatever ``next_method`` returns.
    :raises TaskDeferralTimeout: on the fail signal when the reported ``error`` equals
        ``TriggerFailureReason.TRIGGER_TIMEOUT``.
    :raises TaskDeferralError: on the fail signal for any other ``error`` (default ``"Unknown"``).
    """
    from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout

    if next_kwargs is None:
        next_kwargs = {}

    # ``TRIGGER_FAIL_REPR`` (presumably the ``__fail__`` value) is a special signal
    # indicating this task was scheduled specifically to fail.
    if next_method == TRIGGER_FAIL_REPR:
        traceback = next_kwargs.get("traceback")
        if traceback is not None:
            # Assumes ``traceback`` is a sequence of text lines -- TODO confirm with the trigger side.
            self.log.error("Trigger failed:\n%s", "\n".join(traceback))
        error = next_kwargs.get("error", "Unknown")
        if error == TriggerFailureReason.TRIGGER_TIMEOUT:
            raise TaskDeferralTimeout(error)
        raise TaskDeferralError(error)

    # Grab the resume method off the Operator/Task and call it with any provided kwargs.
    execute_callable = getattr(self, next_method)
    return execute_callable(context, **next_kwargs)

1635 

def dry_run(self) -> None:
    """
    Perform a dry run of the operator by logging its string template fields.

    For each name in ``self.template_fields``, the current attribute value is
    looked up, and non-empty string values are logged. Nothing is rendered or executed.

    :raises AttributeError: if a name listed in ``template_fields`` is not an
        attribute of this operator. The original lookup error is chained as ``__cause__``.
    """
    self.log.info("Dry run")
    for field_name in self.template_fields:
        try:
            content = getattr(self, field_name)
        except AttributeError as err:
            # Chain explicitly so the traceback reads as a cause, not as a second
            # error raised "during handling" of the first.
            raise AttributeError(
                f"{field_name!r} is configured as a template field "
                f"but {self.task_type} does not have this attribute."
            ) from err

        if content and isinstance(content, str):
            self.log.info("Rendering template for %s", field_name)
            self.log.info(content)

1651 

@property
def has_on_execute_callback(self) -> bool:
    """Whether an execute callback is configured (``on_execute_callback`` is truthy)."""
    return True if self.on_execute_callback else False

1656 

@property
def has_on_failure_callback(self) -> bool:
    """Whether a failure callback is configured (``on_failure_callback`` is truthy)."""
    return True if self.on_failure_callback else False

1661 

@property
def has_on_success_callback(self) -> bool:
    """Whether a success callback is configured (``on_success_callback`` is truthy)."""
    return True if self.on_success_callback else False

1666 

@property
def has_on_retry_callback(self) -> bool:
    """Whether a retry callback is configured (``on_retry_callback`` is truthy)."""
    return True if self.on_retry_callback else False

1671 

@property
def has_on_skipped_callback(self) -> bool:
    """Whether a skipped callback is configured (``on_skipped_callback`` is truthy)."""
    return True if self.on_skipped_callback else False

1676 

1677 

def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
    r"""
    Given a number of tasks, builds a dependency chain.

    This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
    or lists containing any mix of these types (or a mix in the same list). If you want to chain between two
    lists you must ensure they have the same length: adjacent lists are linked element-wise.

    Using classic operators/sensors:

    .. code-block:: python

        chain(t1, [t2, t3], [t4, t5], t6)

    is equivalent to::

          / -> t2 -> t4 \
        t1               -> t6
          \ -> t3 -> t5 /

    .. code-block:: python

        t1.set_downstream(t2)
        t1.set_downstream(t3)
        t2.set_downstream(t4)
        t3.set_downstream(t5)
        t4.set_downstream(t6)
        t5.set_downstream(t6)

    Using task-decorated functions aka XComArgs:

    .. code-block:: python

        chain(x1(), [x2(), x3()], [x4(), x5()], x6())

    is equivalent to::

          / -> x2 -> x4 \
        x1               -> x6
          \ -> x3 -> x5 /

    .. code-block:: python

        x1 = x1()
        x2 = x2()
        x3 = x3()
        x4 = x4()
        x5 = x5()
        x6 = x6()
        x1.set_downstream(x2)
        x1.set_downstream(x3)
        x2.set_downstream(x4)
        x3.set_downstream(x5)
        x4.set_downstream(x6)
        x5.set_downstream(x6)

    Using TaskGroups:

    .. code-block:: python

        chain(t1, task_group1, task_group2, t2)

    is equivalent to:

    .. code-block:: python

        t1.set_downstream(task_group1)
        task_group1.set_downstream(task_group2)
        task_group2.set_downstream(t2)

    It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:

    .. code-block:: python

        chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())

    is equivalent to::

          / "branch one" -> x1 \
        t1                      -> task_group1 -> x3
          \ "branch two" -> x2 /

    .. code-block:: python

        x1 = x1()
        x2 = x2()
        x3 = x3()
        label1 = Label("branch one")
        label2 = Label("branch two")
        t1.set_downstream(label1)
        label1.set_downstream(x1)
        t1.set_downstream(label2)
        label2.set_downstream(x2)
        x1.set_downstream(task_group1)
        x2.set_downstream(task_group1)
        task_group1.set_downstream(x3)

        # or

        x1 = x1()
        x2 = x2()
        x3 = x3()
        t1.set_downstream(x1, edge_modifier=Label("branch one"))
        t1.set_downstream(x2, edge_modifier=Label("branch two"))
        x1.set_downstream(task_group1)
        x2.set_downstream(task_group1)
        task_group1.set_downstream(x3)

    :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
    :raises TypeError: if two adjacent arguments are neither ``DependencyMixin`` instances nor sequences.
    :raises ValueError: if two adjacent sequences have different lengths.
    """
    for up_task, down_task in zip(tasks, tasks[1:]):
        # A single node on either side is linked to the other side as a whole.
        if isinstance(up_task, DependencyMixin):
            up_task.set_downstream(down_task)
            continue
        if isinstance(down_task, DependencyMixin):
            down_task.set_upstream(up_task)
            continue
        # Otherwise both sides must be sequences, which are linked element-wise.
        if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
            raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
        if len(up_task) != len(down_task):
            raise ValueError(
                f"Chain not supported for different length Iterable. "
                f"Got {len(up_task)} and {len(down_task)}."
            )
        for up_t, down_t in zip(up_task, down_task):
            up_t.set_downstream(down_t)

1804 

1805 

def cross_downstream(
    from_tasks: Sequence[DependencyMixin],
    to_tasks: DependencyMixin | Sequence[DependencyMixin],
):
    r"""
    Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.

    Each task in ``from_tasks`` gets ``set_downstream(to_tasks)`` called once,
    with ``to_tasks`` passed through unchanged.

    Using classic operators/sensors:

    .. code-block:: python

        cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])

    is equivalent to::

        t1 ---> t4
           \ /
        t2 -X -> t5
           / \
        t3 ---> t6

    .. code-block:: python

        t1.set_downstream(t4)
        t1.set_downstream(t5)
        t1.set_downstream(t6)
        t2.set_downstream(t4)
        t2.set_downstream(t5)
        t2.set_downstream(t6)
        t3.set_downstream(t4)
        t3.set_downstream(t5)
        t3.set_downstream(t6)

    Using task-decorated functions aka XComArgs:

    .. code-block:: python

        cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])

    is equivalent to::

        x1 ---> x4
           \ /
        x2 -X -> x5
           / \
        x3 ---> x6

    .. code-block:: python

        x1 = x1()
        x2 = x2()
        x3 = x3()
        x4 = x4()
        x5 = x5()
        x6 = x6()
        x1.set_downstream(x4)
        x1.set_downstream(x5)
        x1.set_downstream(x6)
        x2.set_downstream(x4)
        x2.set_downstream(x5)
        x2.set_downstream(x6)
        x3.set_downstream(x4)
        x3.set_downstream(x5)
        x3.set_downstream(x6)

    It is also possible to mix between classic operator/sensor and XComArg tasks:

    .. code-block:: python

        cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])

    is equivalent to::

        t1 ---> x1
           \ /
        x2 -X -> t2
           / \
        t3 ---> x3

    .. code-block:: python

        x1 = x1()
        x2 = x2()
        x3 = x3()
        t1.set_downstream(x1)
        t1.set_downstream(t2)
        t1.set_downstream(x3)
        x2.set_downstream(x1)
        x2.set_downstream(t2)
        x2.set_downstream(x3)
        t3.set_downstream(x1)
        t3.set_downstream(t2)
        t3.set_downstream(x3)

    :param from_tasks: List of tasks or XComArgs to start from.
    :param to_tasks: Task/XComArg, or list of them, to set as downstream dependencies.
    """
    for upstream in from_tasks:
        upstream.set_downstream(to_tasks)

1905 

1906 

def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
    """
    Simplify task dependency definition.

    Each task in an element is linked with ``>>`` to the whole next element.
    Consecutive lists are therefore connected all-to-all, rather than paired
    element-wise as in ``chain``.

    E.g.: suppose you want precedence like so::

            ╭─op2─╮ ╭─op4─╮
        op1─┤     ├─├─op5─┤─op7
            ╰-op3─╯ ╰-op6─╯

    Then you can accomplish like so::

        chain_linear(op1, [op2, op3], [op4, op5, op6], op7)

    :param elements: a list of operators / lists of operators
    :raises ValueError: if no elements are given; if an element is an
        ``EdgeModifier`` (Label); or if no dependency ends up being set, e.g. when
        a single list is passed without unpacking it with ``*``.
    """
    if not elements:
        raise ValueError("No tasks provided; nothing to do.")
    previous = None
    linked = False
    for current in elements:
        if isinstance(current, EdgeModifier):
            raise ValueError("Labels are not supported by chain_linear")
        if previous is not None:
            for upstream in previous:
                upstream >> current
            linked = True
        # Normalise a lone node to a one-item list so the next step can iterate it.
        previous = [current] if isinstance(current, DependencyMixin) else current
    if not linked:
        raise ValueError("No dependencies were set. Did you forget to expand with `*`?")