Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/bases/operator.py: 36% of 711 statements


1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import abc 

21import asyncio 

22import collections.abc 

23import contextlib 

24import copy 

25import inspect 

26import sys 

27import warnings 

28from asyncio import AbstractEventLoop 

29from collections.abc import Callable, Collection, Generator, Iterable, Mapping, Sequence 

30from contextvars import ContextVar 

31from dataclasses import dataclass, field 

32from datetime import datetime, timedelta 

33from enum import Enum 

34from functools import total_ordering, wraps 

35from types import FunctionType 

36from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast 

37 

38import attrs 

39 

40from airflow.sdk import TriggerRule, timezone 

41from airflow.sdk._shared.secrets_masker import redact 

42from airflow.sdk.definitions._internal.abstractoperator import ( 

43 DEFAULT_EMAIL_ON_FAILURE, 

44 DEFAULT_EMAIL_ON_RETRY, 

45 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

46 DEFAULT_OWNER, 

47 DEFAULT_POOL_NAME, 

48 DEFAULT_POOL_SLOTS, 

49 DEFAULT_PRIORITY_WEIGHT, 

50 DEFAULT_QUEUE, 

51 DEFAULT_RETRIES, 

52 DEFAULT_RETRY_DELAY, 

53 DEFAULT_TASK_EXECUTION_TIMEOUT, 

54 DEFAULT_TRIGGER_RULE, 

55 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

56 DEFAULT_WEIGHT_RULE, 

57 AbstractOperator, 

58 DependencyMixin, 

59 TaskStateChangeCallback, 

60) 

61from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack 

62from airflow.sdk.definitions._internal.node import validate_key 

63from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext 

64from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args 

65from airflow.sdk.definitions.edges import EdgeModifier 

66from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs 

67from airflow.sdk.definitions.param import ParamsDict 

68from airflow.sdk.exceptions import RemovedInAirflow4Warning 

69 

70# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights. 

71# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html) 

72# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html) 

73# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html) 

74DB_SAFE_MINIMUM = -2147483648 

75DB_SAFE_MAXIMUM = 2147483647 

76 

77 

78def db_safe_priority(priority_weight: int) -> int: 

79 """Convert priority weight to a safe value for the database.""" 

80 return max(DB_SAFE_MINIMUM, min(DB_SAFE_MAXIMUM, priority_weight)) 

81 
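A quick, illustrative check of the clamping above, using the constants defined in this module:

    # Values inside the 32-bit range pass through; values outside are clamped.
    assert db_safe_priority(5) == 5
    assert db_safe_priority(10_000_000_000) == DB_SAFE_MAXIMUM    # 2147483647
    assert db_safe_priority(-10_000_000_000) == DB_SAFE_MINIMUM   # -2147483648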

82 

83C = TypeVar("C", bound=Callable) 

84T = TypeVar("T", bound=FunctionType) 

85 

86if TYPE_CHECKING: 

87 from types import ClassMethodDescriptorType 

88 

89 import jinja2 

90 from typing_extensions import Self 

91 

92 from airflow.sdk.api.datamodels._generated import DagAttributeTypes 

93 from airflow.sdk.bases.operatorlink import BaseOperatorLink 

94 from airflow.sdk.definitions.context import Context 

95 from airflow.sdk.definitions.dag import DAG 

96 from airflow.sdk.definitions.operator_resources import Resources 

97 from airflow.sdk.definitions.taskgroup import TaskGroup 

98 from airflow.sdk.definitions.xcom_arg import XComArg 

99 from airflow.task.priority_strategy import PriorityWeightStrategy 

100 from airflow.triggers.base import BaseTrigger, StartTriggerArgs 

101 

102 TaskPreExecuteHook = Callable[[Context], None] 

103 TaskPostExecuteHook = Callable[[Context, Any], None] 

104 

105__all__ = [ 

106 "BaseAsyncOperator", 

107 "BaseOperator", 

108 "chain", 

109 "chain_linear", 

110 "cross_downstream", 

111] 

112 

113 

114class TriggerFailureReason(str, Enum): 

115 """ 

116 Reasons for trigger failures. 

117 

118 Internal use only. 

119 

120 :meta private: 

121 """ 

122 

123 TRIGGER_TIMEOUT = "Trigger timeout" 

124 TRIGGER_FAILURE = "Trigger failure" 

125 

126 

127TRIGGER_FAIL_REPR = "__fail__" 

128"""String value to represent trigger failure. 

129 

130Internal use only. 

131 

132:meta private: 

133""" 

134 

135 

136def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]: 

137 if not dag: 

138 return {}, ParamsDict() 

139 dag_args = copy.copy(dag.default_args) 

140 dag_params = copy.deepcopy(dag.params) 

141 dag_params._fill_missing_param_source("dag") 

142 if task_group: 

143 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping): 

144 raise TypeError("default_args must be a mapping") 

145 dag_args.update(task_group.default_args) 

146 return dag_args, dag_params 

147 

148 

149def get_merged_defaults( 

150 dag: DAG | None, 

151 task_group: TaskGroup | None, 

152 task_params: collections.abc.MutableMapping | None, 

153 task_default_args: dict | None, 

154) -> tuple[dict, ParamsDict]: 

155 args, params = _get_parent_defaults(dag, task_group) 

156 if task_params: 

157 if not isinstance(task_params, collections.abc.Mapping): 

158 raise TypeError(f"params must be a mapping, got {type(task_params)}") 

159 

160 task_params = ParamsDict(task_params) 

161 task_params._fill_missing_param_source("task") 

162 params.update(task_params) 

163 

164 if task_default_args: 

165 if not isinstance(task_default_args, collections.abc.Mapping): 

166 raise TypeError(f"default_args must be a mapping, got {type(task_default_args)}")

167 args.update(task_default_args) 

168 with contextlib.suppress(KeyError): 

169 if params_from_default_args := ParamsDict(task_default_args["params"] or {}): 

170 params_from_default_args._fill_missing_param_source("task") 

171 params.update(params_from_default_args) 

172 

173 return args, params 

174 
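A minimal sketch of the merge order implemented above; the ``dag`` object and values are illustrative only:

    # Suppose dag was built roughly as:
    #   dag = DAG("example", default_args={"retries": 2}, params={"env": "dev"})
    # Task-level values take precedence over DAG/TaskGroup-level defaults.
    args, params = get_merged_defaults(
        dag=dag,
        task_group=None,
        task_params={"env": "prod"},       # overrides the DAG-level "env" param
        task_default_args={"retries": 5},  # overrides the DAG-level "retries"
    )
    # args["retries"] == 5 and params["env"] == "prod"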

175 

176def parse_retries(retries: Any) -> int | None: 

177 if retries is None: 

178 return 0 

179 if type(retries) == int: # noqa: E721 

180 return retries 

181 try: 

182 parsed_retries = int(retries) 

183 except (TypeError, ValueError): 

184 raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}") 

185 return parsed_retries 

186 

187 

188def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta: 

189 if isinstance(value, timedelta): 

190 return value 

191 return timedelta(seconds=value) 

192 

193 

194def coerce_resources(resources: dict[str, Any] | None) -> Resources | None: 

195 if resources is None: 

196 return None 

197 from airflow.sdk.definitions.operator_resources import Resources 

198 

199 return Resources(**resources) 

200 

201 

202@contextlib.contextmanager 

203def event_loop() -> Generator[AbstractEventLoop]: 

204 new_event_loop = False 

205 loop = None 

206 try: 

207 try: 

208 loop = asyncio.get_event_loop() 

209 if loop.is_closed(): 

210 raise RuntimeError 

211 except RuntimeError: 

212 loop = asyncio.new_event_loop() 

213 asyncio.set_event_loop(loop) 

214 new_event_loop = True 

215 yield loop 

216 finally: 

217 if new_event_loop and loop is not None: 

218 with contextlib.suppress(AttributeError): 

219 loop.close() 

220 asyncio.set_event_loop(None) 

221 
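A usage sketch for the context manager above; the coroutine is a stand-in:

    async def _noop() -> None:
        return None

    # Reuses the current event loop when one is usable, otherwise creates a
    # fresh loop and closes it again on exit.
    with event_loop() as loop:
        loop.run_until_complete(_noop())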

222 

223class _PartialDescriptor: 

224 """A descriptor that guards against ``.partial`` being called on Task objects.""" 

225 

226 class_method: ClassMethodDescriptorType | None = None 

227 

228 def __get__( 

229 self, obj: BaseOperator, cls: type[BaseOperator] | None = None 

230 ) -> Callable[..., OperatorPartial]: 

231 # Call this "partial" so it looks nicer in stack traces. 

232 def partial(**kwargs): 

233 raise TypeError("partial can only be called on Operator classes, not Tasks themselves") 

234 

235 if obj is not None: 

236 return partial 

237 return self.class_method.__get__(cls, cls) 

238 

239 

240OPERATOR_DEFAULTS: dict[str, Any] = { 

241 "allow_nested_operators": True, 

242 "depends_on_past": False, 

243 "email_on_failure": DEFAULT_EMAIL_ON_FAILURE, 

244 "email_on_retry": DEFAULT_EMAIL_ON_RETRY, 

245 "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, 

246 # "executor": DEFAULT_EXECUTOR, 

247 "executor_config": {}, 

248 "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

249 "inlets": [], 

250 "map_index_template": None, 

251 "on_execute_callback": [], 

252 "on_failure_callback": [], 

253 "on_retry_callback": [], 

254 "on_skipped_callback": [], 

255 "on_success_callback": [], 

256 "outlets": [], 

257 "owner": DEFAULT_OWNER, 

258 "pool_slots": DEFAULT_POOL_SLOTS, 

259 "priority_weight": DEFAULT_PRIORITY_WEIGHT, 

260 "queue": DEFAULT_QUEUE, 

261 "retries": DEFAULT_RETRIES, 

262 "retry_delay": DEFAULT_RETRY_DELAY, 

263 "retry_exponential_backoff": 0, 

264 "trigger_rule": DEFAULT_TRIGGER_RULE, 

265 "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

266 "wait_for_downstream": False, 

267 "weight_rule": DEFAULT_WEIGHT_RULE, 

268} 

269 

270 

271# This is what handles the actual mapping. 

272 

273if TYPE_CHECKING: 

274 

275 def partial( 

276 operator_class: type[BaseOperator], 

277 *, 

278 task_id: str, 

279 dag: DAG | None = None, 

280 task_group: TaskGroup | None = None, 

281 start_date: datetime = ..., 

282 end_date: datetime = ..., 

283 owner: str = ..., 

284 email: None | str | Iterable[str] = ..., 

285 params: collections.abc.MutableMapping | None = None, 

286 resources: dict[str, Any] | None = ..., 

287 trigger_rule: str = ..., 

288 depends_on_past: bool = ..., 

289 ignore_first_depends_on_past: bool = ..., 

290 wait_for_past_depends_before_skipping: bool = ..., 

291 wait_for_downstream: bool = ..., 

292 retries: int | None = ..., 

293 queue: str = ..., 

294 pool: str = ..., 

295 pool_slots: int = ..., 

296 execution_timeout: timedelta | None = ..., 

297 max_retry_delay: None | timedelta | float = ..., 

298 retry_delay: timedelta | float = ..., 

299 retry_exponential_backoff: float = ..., 

300 priority_weight: int = ..., 

301 weight_rule: str | PriorityWeightStrategy = ..., 

302 sla: timedelta | None = ..., 

303 map_index_template: str | None = ..., 

304 max_active_tis_per_dag: int | None = ..., 

305 max_active_tis_per_dagrun: int | None = ..., 

306 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

307 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

308 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

309 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

310 on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ..., 

311 run_as_user: str | None = ..., 

312 executor: str | None = ..., 

313 executor_config: dict | None = ..., 

314 inlets: Any | None = ..., 

315 outlets: Any | None = ..., 

316 doc: str | None = ..., 

317 doc_md: str | None = ..., 

318 doc_json: str | None = ..., 

319 doc_yaml: str | None = ..., 

320 doc_rst: str | None = ..., 

321 task_display_name: str | None = ..., 

322 logger_name: str | None = ..., 

323 allow_nested_operators: bool = True, 

324 **kwargs, 

325 ) -> OperatorPartial: ... 

326else: 

327 

328 def partial( 

329 operator_class: type[BaseOperator], 

330 *, 

331 task_id: str, 

332 dag: DAG | None = None, 

333 task_group: TaskGroup | None = None, 

334 params: collections.abc.MutableMapping | None = None, 

335 **kwargs, 

336 ): 

337 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext 

338 

339 validate_mapping_kwargs(operator_class, "partial", kwargs) 

340 

341 dag = dag or DagContext.get_current() 

342 if dag: 

343 task_group = task_group or TaskGroupContext.get_current(dag) 

344 if task_group: 

345 task_id = task_group.child_id(task_id) 

346 

347 # Merge Dag and task group level defaults into user-supplied values. 

348 dag_default_args, partial_params = get_merged_defaults( 

349 dag=dag, 

350 task_group=task_group, 

351 task_params=params, 

352 task_default_args=kwargs.pop("default_args", None), 

353 ) 

354 

355 # Create partial_kwargs from args and kwargs 

356 partial_kwargs: dict[str, Any] = { 

357 "task_id": task_id, 

358 "dag": dag, 

359 "task_group": task_group, 

360 **kwargs, 

361 } 

362 

363 # Inject Dag-level default args into args provided to this function. 

364 # Most of the default args will be retrieved during unmapping; here we 

365 # only ensure base properties are correctly set for the scheduler. 

366 partial_kwargs.update( 

367 (k, v) 

368 for k, v in dag_default_args.items() 

369 if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names 

370 ) 

371 

372 # Fill fields not provided by the user with default values. 

373 partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs) 

374 

375 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). 

376 if "task_concurrency" in kwargs: # Reject deprecated option. 

377 raise TypeError("unexpected argument: task_concurrency") 

378 if start_date := partial_kwargs.get("start_date", None): 

379 partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) 

380 if end_date := partial_kwargs.get("end_date", None): 

381 partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) 

382 if partial_kwargs["pool_slots"] < 1: 

383 dag_str = "" 

384 if dag: 

385 dag_str = f" in dag {dag.dag_id}" 

386 raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") 

387 if retries := partial_kwargs.get("retries"): 

388 partial_kwargs["retries"] = BaseOperator._convert_retries(retries) 

389 partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"]) 

390 partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay( 

391 partial_kwargs.get("max_retry_delay", None) 

392 ) 

393 

394 for k in ("execute", "failure", "success", "retry", "skipped"): 

395 partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr := f"on_{k}_callback")) 

396 

397 return OperatorPartial( 

398 operator_class=operator_class, 

399 kwargs=partial_kwargs, 

400 params=partial_params, 

401 ) 

402 
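This function is the machinery behind ``Operator.partial(...).expand(...)`` (dynamic task mapping). A hedged usage sketch; the operator import path may differ between Airflow versions:

    from airflow.providers.standard.operators.bash import BashOperator

    # Inside a `with DAG(...)` block: fix the common arguments with .partial(),
    # then map over the remaining argument with .expand().
    BashOperator.partial(task_id="probe", retries=2).expand(
        bash_command=["echo 1", "echo 2", "echo 3"],
    )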

403 

404class ExecutorSafeguard: 

405 """ 

406 The ExecutorSafeguard decorator. 

407 

408 Checks if the execute method of an operator isn't manually called outside 

409 the TaskInstance as we want to avoid bad mixing between decorated and 

410 classic operators. 

411 """ 

412 

413 test_mode: ClassVar[bool] = False 

414 tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel") 

415 sentinel_value: ClassVar[object] = object() 

416 

417 @classmethod 

418 def decorator(cls, func): 

419 @wraps(func) 

420 def wrapper(self, *args, **kwargs): 

421 sentinel_key = f"{self.__class__.__name__}__sentinel" 

422 sentinel = kwargs.pop(sentinel_key, None) 

423 

424 with contextlib.ExitStack() as stack: 

425 if sentinel is cls.sentinel_value: 

426 token = cls.tracker.set(self) 

427 sentinel = self 

428 stack.callback(cls.tracker.reset, token) 

429 else: 

430 # No sentinel passed in; maybe the subclass execute already had it passed?

431 sentinel = cls.tracker.get(None) 

432 

433 if not cls.test_mode and sentinel is not self: 

434 message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!" 

435 if not self.allow_nested_operators: 

436 raise RuntimeError(message) 

437 self.log.warning(message) 

438 

439 # Now that we've logged, set sentinel so that `super()` calls don't log again 

440 token = cls.tracker.set(self) 

441 stack.callback(cls.tracker.reset, token) 

442 

443 return func(self, *args, **kwargs) 

444 

445 return wrapper 

446 
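In practice the safeguard behaves roughly as sketched below; ``MyOperator`` and ``context`` are placeholders:

    # Calling execute() by hand, outside the Task Runner, trips the safeguard.
    op = MyOperator(task_id="nested", allow_nested_operators=False)
    op.execute(context)
    # -> RuntimeError: MyOperator.execute cannot be called outside of the Task Runner!
    # With allow_nested_operators=True (the default) only a warning is logged.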

447 

448if "airflow.configuration" in sys.modules: 

449 # Don't try and import it if it's not already loaded 

450 from airflow.sdk.configuration import conf 

451 

452 ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode") 

453 

454 

455def _collect_from_input(value_or_values: None | C | Collection[C]) -> list[C]: 

456 if not value_or_values: 

457 return [] 

458 if isinstance(value_or_values, Collection): 

459 return list(value_or_values) 

460 return [value_or_values] 

461 
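The helper normalises callback-style arguments into a list; ``cb`` here stands for any callable:

    _collect_from_input(None)       # -> []
    _collect_from_input(cb)         # -> [cb]
    _collect_from_input([cb, cb])   # -> [cb, cb]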

462 

463class BaseOperatorMeta(abc.ABCMeta): 

464 """Metaclass of BaseOperator.""" 

465 

466 @classmethod 

467 def _apply_defaults(cls, func: T) -> T: 

468 """ 

469 Look for an argument named "default_args", and fill the unspecified arguments from it. 

470 

471 Since python2.* isn't clear about which arguments are missing when 

472 calling a function, and this can be quite confusing with multi-level

473 inheritance and argument defaults, this decorator also alerts with 

474 specific information about the missing arguments. 

475 """ 

476 # Cache inspect.signature for the wrapper closure to avoid calling it 

477 # at every decorated invocation. This is separate sig_cache created 

478 # per decoration, i.e. each function decorated using apply_defaults will 

479 # have a different sig_cache. 

480 sig_cache = inspect.signature(func) 

481 non_variadic_params = { 

482 name: param 

483 for (name, param) in sig_cache.parameters.items() 

484 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) 

485 } 

486 non_optional_args = { 

487 name 

488 for name, param in non_variadic_params.items() 

489 if param.default == param.empty and name != "task_id" 

490 } 

491 

492 fixup_decorator_warning_stack(func) 

493 

494 @wraps(func) 

495 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: 

496 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext 

497 

498 if args: 

499 raise TypeError("Use keyword arguments when initializing operators") 

500 

501 instantiated_from_mapped = kwargs.pop( 

502 "_airflow_from_mapped", 

503 getattr(self, "_BaseOperator__from_mapped", False), 

504 ) 

505 

506 dag: DAG | None = kwargs.get("dag") 

507 if dag is None: 

508 dag = DagContext.get_current() 

509 if dag is not None: 

510 kwargs["dag"] = dag 

511 

512 task_group: TaskGroup | None = kwargs.get("task_group") 

513 if dag and not task_group: 

514 task_group = TaskGroupContext.get_current(dag) 

515 if task_group is not None: 

516 kwargs["task_group"] = task_group 

517 

518 default_args, merged_params = get_merged_defaults( 

519 dag=dag, 

520 task_group=task_group, 

521 task_params=kwargs.pop("params", None), 

522 task_default_args=kwargs.pop("default_args", None), 

523 ) 

524 

525 for arg in sig_cache.parameters: 

526 if arg not in kwargs and arg in default_args: 

527 kwargs[arg] = default_args[arg] 

528 

529 missing_args = non_optional_args.difference(kwargs) 

530 if len(missing_args) == 1: 

531 raise TypeError(f"missing keyword argument {missing_args.pop()!r}") 

532 if missing_args: 

533 display = ", ".join(repr(a) for a in sorted(missing_args)) 

534 raise TypeError(f"missing keyword arguments {display}") 

535 

536 if merged_params: 

537 kwargs["params"] = merged_params 

538 

539 hook = getattr(self, "_hook_apply_defaults", None) 

540 if hook: 

541 args, kwargs = hook(**kwargs, default_args=default_args) 

542 default_args = kwargs.pop("default_args", {}) 

543 

544 if not hasattr(self, "_BaseOperator__init_kwargs"): 

545 object.__setattr__(self, "_BaseOperator__init_kwargs", {}) 

546 object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped) 

547 

548 result = func(self, **kwargs, default_args=default_args) 

549 

550 # Store the args passed to init -- we need them to support task.map serialization! 

551 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore 

552 

553 # Set upstream tasks defined by XComArgs passed to template fields of the operator.

554 # BUT: only do this _ONCE_, not once for each class in the hierarchy 

555 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc] 

556 self._set_xcomargs_dependencies() 

557 # Mark instance as instantiated so that future attr setting updates xcomarg-based deps. 

558 object.__setattr__(self, "_BaseOperator__instantiated", True) 

559 

560 return result 

561 

562 apply_defaults.__non_optional_args = non_optional_args # type: ignore 

563 apply_defaults.__param_names = set(non_variadic_params) # type: ignore 

564 

565 return cast("T", apply_defaults) 

566 

567 def __new__(cls, name, bases, namespace, **kwargs): 

568 execute_method = namespace.get("execute") 

569 if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False): 

570 namespace["execute"] = ExecutorSafeguard.decorator(execute_method) 

571 new_cls = super().__new__(cls, name, bases, namespace, **kwargs) 

572 with contextlib.suppress(KeyError): 

573 # Update the partial descriptor with the class method, so it calls the actual function 

574 # (but let subclasses override it if they need to) 

575 partial_desc = vars(new_cls)["partial"] 

576 if isinstance(partial_desc, _PartialDescriptor): 

577 partial_desc.class_method = classmethod(partial) 

578 

579 # We patch `__init__` only if the class defines it. 

580 first_superclass = new_cls.mro()[1] 

581 if new_cls.__init__ is not first_superclass.__init__: 

582 new_cls.__init__ = cls._apply_defaults(new_cls.__init__) 

583 

584 return new_cls 

585 
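The net effect of the metaclass above is that DAG-level ``default_args`` fill in operator arguments the user did not pass. A minimal sketch, assuming ``DAG`` is importable from ``airflow.sdk``:

    from airflow.sdk import DAG

    class _Noop(BaseOperator):
        def execute(self, context):
            return None

    with DAG("defaults_demo", default_args={"retries": 3, "owner": "data-eng"}):
        task = _Noop(task_id="noop")

    # apply_defaults() filled the unspecified arguments from the DAG:
    # task.retries == 3 and task.owner == "data-eng"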

586 

587# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the 

588# correct type. This is a temporary solution until we find a more sophisticated method for argument 

589# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not 

590# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python 

591# version that supports `get_type_hints` effectively or find a better approach, we can replace this 

592# manual type-checking method. 

593BASEOPERATOR_ARGS_EXPECTED_TYPES = { 

594 "task_id": str, 

595 "email": (str, Sequence), 

596 "email_on_retry": bool, 

597 "email_on_failure": bool, 

598 "retries": int, 

599 "retry_exponential_backoff": (int, float), 

600 "depends_on_past": bool, 

601 "ignore_first_depends_on_past": bool, 

602 "wait_for_past_depends_before_skipping": bool, 

603 "wait_for_downstream": bool, 

604 "priority_weight": int, 

605 "queue": str, 

606 "pool": str, 

607 "pool_slots": int, 

608 "trigger_rule": str, 

609 "run_as_user": str, 

610 "task_concurrency": int, 

611 "map_index_template": str, 

612 "max_active_tis_per_dag": int, 

613 "max_active_tis_per_dagrun": int, 

614 "executor": str, 

615 "do_xcom_push": bool, 

616 "multiple_outputs": bool, 

617 "doc": str, 

618 "doc_md": str, 

619 "doc_json": str, 

620 "doc_yaml": str, 

621 "doc_rst": str, 

622 "task_display_name": str, 

623 "logger_name": str, 

624 "allow_nested_operators": bool, 

625 "start_date": datetime, 

626 "end_date": datetime, 

627} 

628 

629 

630# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in 

631# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it. 

632# 

633# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you 

634# get nowhere, leave your record here for the next poor soul and note what problems you ran into.

635# 

636# @ashb, 2024/10/14 

637# - "Can't combine custom __setattr__ with on_setattr hooks" 

638# - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses

639@total_ordering 

640@dataclass(repr=False) 

641class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): 

642 r""" 

643 Abstract base class for all operators. 

644 

645 Since operators create objects that become nodes in the Dag, BaseOperator 

646 contains many recursive methods for Dag crawling behavior. To derive from 

647 this class, you are expected to override the constructor and the 'execute'

648 method (a minimal subclass sketch follows this docstring).

649 

650 Operators derived from this class should perform or trigger certain tasks 

651 synchronously (wait for completion). Example of operators could be an 

652 operator that runs a Pig job (PigOperator), a sensor operator that 

653 waits for a partition to land in Hive (HiveSensorOperator), or one that 

654 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these 

655 operators (tasks) target specific operations, running specific scripts, 

656 functions or data transfers. 

657 

658 This class is abstract and shouldn't be instantiated. Instantiating a 

659 class derived from this one results in the creation of a task object, 

660 which ultimately becomes a node in Dag objects. Task dependencies should 

661 be set by using the set_upstream and/or set_downstream methods. 

662 

663 :param task_id: a unique, meaningful id for the task 

664 :param owner: the owner of the task. Using a meaningful description 

665 (e.g. user/person/team/role name) to clarify ownership is recommended. 

666 :param email: the 'to' email address(es) used in email alerts. This can be a 

667 single email or multiple ones. Multiple addresses can be specified as a 

668 comma or semicolon separated string or by passing a list of strings. (deprecated) 

669 :param email_on_retry: Indicates whether email alerts should be sent when a 

670 task is retried (deprecated) 

671 :param email_on_failure: Indicates whether email alerts should be sent when 

672 a task failed (deprecated) 

673 :param retries: the number of retries that should be performed before 

674 failing the task 

675 :param retry_delay: delay between retries, can be set as ``timedelta`` or 

676 ``float`` seconds, which will be converted into ``timedelta``, 

677 the default is ``timedelta(seconds=300)``. 

678 :param retry_exponential_backoff: multiplier for exponential backoff between retries. 

679 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff 

680 (delay doubles with each retry). For example, with retry_delay=4min and 

681 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc. 

682 :param max_retry_delay: maximum delay interval between retries, can be set as 

683 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``. 

684 :param start_date: The ``start_date`` for the task, determines 

685 the ``logical_date`` for the first task instance. The best practice 

686 is to have the start_date rounded 

687 to your Dag's ``schedule_interval``. Daily jobs have their start_date 

688 some day at 00:00:00, hourly jobs have their start_date at 00:00 

689 of a specific hour. Note that Airflow simply looks at the latest 

690 ``logical_date`` and adds the ``schedule_interval`` to determine 

691 the next ``logical_date``. It is also very important 

692 to note that different tasks' dependencies 

693 need to line up in time. If task A depends on task B and their 

694 start_date are offset in a way that their logical_date don't line 

695 up, A's dependencies will never be met. If you are looking to delay 

696 a task, for example running a daily task at 2AM, look into the 

697 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using 

698 dynamic ``start_date`` and recommend using fixed ones. Read the 

699 FAQ entry about start_date for more information. 

700 :param end_date: if specified, the scheduler won't go beyond this date 

701 :param depends_on_past: when set to true, task instances will run 

702 sequentially and only if the previous instance has succeeded or has been skipped. 

703 The task instance for the start_date is allowed to run. 

704 :param wait_for_past_depends_before_skipping: when set to true, if the task instance 

705 should be marked as skipped, and depends_on_past is true, the ti will stay in None state

706 waiting for the task of the previous run

707 :param wait_for_downstream: when set to true, an instance of task 

708 X will wait for tasks immediately downstream of the previous instance 

709 of task X to finish successfully or be skipped before it runs. This is useful if the 

710 different instances of a task X alter the same asset, and this asset 

711 is used by tasks downstream of task X. Note that depends_on_past 

712 is forced to True wherever wait_for_downstream is used. Also note that 

713 only tasks *immediately* downstream of the previous task instance are waited 

714 for; the statuses of any tasks further downstream are ignored. 

715 :param dag: a reference to the dag the task is attached to (if any) 

716 :param priority_weight: priority weight of this task against other tasks.

717 This allows the executor to trigger higher priority tasks before

718 others when things get backed up. Set priority_weight as a higher

719 number for more important tasks.

720 As not all database engines support 64-bit integers, values are capped to the 32-bit range.

721 Valid range is from -2,147,483,648 to 2,147,483,647. 

722 :param weight_rule: weighting method used for the effective total 

723 priority weight of the task. Options are: 

724 ``{ downstream | upstream | absolute }`` default is ``downstream`` 

725 When set to ``downstream`` the effective weight of the task is the 

726 aggregate sum of all downstream descendants. As a result, upstream 

727 tasks will have higher weight and will be scheduled more aggressively 

728 when using positive weight values. This is useful when you have 

729 multiple dag run instances and desire to have all upstream tasks to 

730 complete for all runs before each dag can continue processing 

731 downstream tasks. When set to ``upstream`` the effective weight is the 

732 aggregate sum of all upstream ancestors. This is the opposite where 

733 downstream tasks have higher weight and will be scheduled more 

734 aggressively when using positive weight values. This is useful when you 

735 have multiple dag run instances and prefer to have each dag complete 

736 before starting upstream tasks of other dags. When set to 

737 ``absolute``, the effective weight is the exact ``priority_weight`` 

738 specified without additional weighting. You may want to do this when 

739 you know exactly what priority weight each task should have. 

740 Additionally, when set to ``absolute``, there is a bonus effect of

741 significantly speeding up the task creation process for very large

742 Dags. Options can be set as string or using the constants defined in

743 the static class ``airflow.utils.WeightRule``.

744 Irrespective of the weight rule, resulting priority values are capped to the 32-bit range.

745 |experimental| 

746 Since 2.9.0, Airflow allows defining a custom priority weight strategy

747 by creating a subclass of

748 ``airflow.task.priority_strategy.PriorityWeightStrategy``, registering

749 it in a plugin, and then providing the class path or the class instance via

750 the ``weight_rule`` parameter. The custom priority weight strategy will be

751 used to calculate the effective total priority weight of the task instance. 

752 :param queue: which queue to target when running this job. Not

753 all executors implement queue management; the CeleryExecutor, for

754 example, does support targeting specific queues.

755 :param pool: the slot pool this task should run in, slot pools are a 

756 way to limit concurrency for certain tasks 

757 :param pool_slots: the number of pool slots this task should use (>= 1) 

758 Values less than 1 are not allowed. 

759 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a 

760 new implementation in Airflow >=3.1. 

761 :param execution_timeout: max time allowed for the execution of 

762 this task instance, if it goes beyond it will raise and fail. 

763 :param on_failure_callback: a function or list of functions to be called when a task instance 

764 of this task fails. a context dictionary is passed as a single 

765 parameter to this function. Context contains references to related 

766 objects to the task instance and is documented under the macros 

767 section of the API. 

768 :param on_execute_callback: much like the ``on_failure_callback`` except 

769 that it is executed right before the task is executed. 

770 :param on_retry_callback: much like the ``on_failure_callback`` except 

771 that it is executed when retries occur. 

772 :param on_success_callback: much like the ``on_failure_callback`` except 

773 that it is executed when the task succeeds. 

774 :param on_skipped_callback: much like the ``on_failure_callback`` except 

775 that it is executed when the task is skipped; this callback is called only if AirflowSkipException gets raised.

776 Explicitly, it is NOT called if a task never starts executing because of a preceding branching

777 decision in the Dag or a trigger rule which causes execution to skip, so that the task execution

778 is never scheduled.

779 :param pre_execute: a function to be called immediately before task 

780 execution, receiving a context dictionary; raising an exception will 

781 prevent the task from being executed. 

782 :param post_execute: a function to be called immediately after task 

783 execution, receiving a context dictionary and task result; raising an 

784 exception will prevent the task from succeeding. 

785 :param trigger_rule: defines the rule by which dependencies are applied 

786 for the task to get triggered. Options are: 

787 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done | 

788 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}`` 

789 default is ``all_success``. Options can be set as string or 

790 using the constants defined in the static class 

791 ``airflow.utils.TriggerRule`` 

792 :param resources: A map of resource parameter names (the argument names of the 

793 Resources constructor) to their values. 

794 :param run_as_user: unix username to impersonate while running the task 

795 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent 

796 runs across logical_dates. 

797 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent 

798 task instances per Dag run. 

799 :param executor: Which executor to target when running this task. NOT YET SUPPORTED 

800 :param executor_config: Additional task-level configuration parameters that are 

801 interpreted by a specific executor. Parameters are namespaced by the name of 

802 executor. 

803 

804 **Example**: to run this task in a specific docker container through 

805 the KubernetesExecutor :: 

806 

807 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}) 

808 

809 :param do_xcom_push: if True, an XCom is pushed containing the Operator's 

810 result 

811 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each 

812 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom. 

813 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not 

814 using a TaskGroup as a context manager. 

815 :param doc: Add documentation or notes to your Task objects that is visible in 

816 Task Instance details View in the Webserver 

817 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects 

818 that is visible in Task Instance details View in the Webserver 

819 :param doc_rst: Add documentation (in RST format) or notes to your Task objects 

820 that is visible in Task Instance details View in the Webserver 

821 :param doc_json: Add documentation (in JSON format) or notes to your Task objects 

822 that is visible in Task Instance details View in the Webserver 

823 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects 

824 that is visible in Task Instance details View in the Webserver 

825 :param task_display_name: The display name of the task which appears on the UI. 

826 :param logger_name: Name of the logger used by the Operator to emit logs. 

827 If set to `None` (default), the logger name will fall back to 

828 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have 

829 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger). 

830 :param allow_nested_operators: if True, when an operator is executed within another one a warning message 

831 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested 

832 within another one). In future releases of Airflow this parameter will be removed and an exception 

833 will always be thrown when operators are nested within each other (default is True). 

834 

835 **Example**: example of a bad operator mixin usage:: 

836 

837 @task(provide_context=True) 

838 def say_hello_world(**context): 

839 hello_world_task = BashOperator( 

840 task_id="hello_world_task", 

841 bash_command="python -c \"print('Hello, world!')\"", 

842 dag=dag, 

843 ) 

844 hello_world_task.execute(context) 

845 :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment`` 

846 to render templates as native Python types. If False, a Jinja 

847 ``Environment`` is used to render templates as string values. 

848 If None (default), inherits from the DAG setting. 

849 """ 

850 

851 task_id: str 

852 owner: str = DEFAULT_OWNER 

853 email: str | Sequence[str] | None = None 

854 email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY 

855 email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE 

856 retries: int | None = DEFAULT_RETRIES 

857 retry_delay: timedelta = DEFAULT_RETRY_DELAY 

858 retry_exponential_backoff: float = 0 

859 max_retry_delay: timedelta | float | None = None 

860 start_date: datetime | None = None 

861 end_date: datetime | None = None 

862 depends_on_past: bool = False 

863 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST 

864 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING 

865 wait_for_downstream: bool = False 

866 

867 # At execution time this becomes a normal dict

868 params: ParamsDict | dict = field(default_factory=ParamsDict) 

869 default_args: dict | None = None 

870 priority_weight: int = DEFAULT_PRIORITY_WEIGHT 

871 weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE) 

872 queue: str = DEFAULT_QUEUE 

873 pool: str = DEFAULT_POOL_NAME 

874 pool_slots: int = DEFAULT_POOL_SLOTS 

875 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT 

876 on_execute_callback: Sequence[TaskStateChangeCallback] = () 

877 on_failure_callback: Sequence[TaskStateChangeCallback] = () 

878 on_success_callback: Sequence[TaskStateChangeCallback] = () 

879 on_retry_callback: Sequence[TaskStateChangeCallback] = () 

880 on_skipped_callback: Sequence[TaskStateChangeCallback] = () 

881 _pre_execute_hook: TaskPreExecuteHook | None = None 

882 _post_execute_hook: TaskPostExecuteHook | None = None 

883 trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE 

884 resources: dict[str, Any] | None = None 

885 run_as_user: str | None = None 

886 task_concurrency: int | None = None 

887 map_index_template: str | None = None 

888 max_active_tis_per_dag: int | None = None 

889 max_active_tis_per_dagrun: int | None = None 

890 executor: str | None = None 

891 executor_config: dict | None = None 

892 do_xcom_push: bool = True 

893 multiple_outputs: bool = False 

894 inlets: list[Any] = field(default_factory=list) 

895 outlets: list[Any] = field(default_factory=list) 

896 task_group: TaskGroup | None = None 

897 doc: str | None = None 

898 doc_md: str | None = None 

899 doc_json: str | None = None 

900 doc_yaml: str | None = None 

901 doc_rst: str | None = None 

902 _task_display_name: str | None = None 

903 logger_name: str | None = None 

904 allow_nested_operators: bool = True 

905 render_template_as_native_obj: bool | None = None 

906 

907 is_setup: bool = False 

908 is_teardown: bool = False 

909 

910 # TODO: Task-SDK: Make these ClassVar[]? 

911 template_fields: Collection[str] = () 

912 template_ext: Sequence[str] = () 

913 

914 template_fields_renderers: ClassVar[dict[str, str]] = {} 

915 

916 operator_extra_links: Collection[BaseOperatorLink] = () 

917 

918 # Defines the color in the UI 

919 ui_color: str = "#fff" 

920 ui_fgcolor: str = "#000" 

921 

922 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore 

923 

924 _dag: DAG | None = field(init=False, default=None) 

925 

926 # Make this optional so the type matches the one defined in LoggingMixin

927 _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False) 

928 _logger_name: str | None = None 

929 

930 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called 

931 __serialized_fields: ClassVar[frozenset[str] | None] = None 

932 

933 _comps: ClassVar[set[str]] = { 

934 "task_id", 

935 "dag_id", 

936 "owner", 

937 "email", 

938 "email_on_retry", 

939 "retry_delay", 

940 "retry_exponential_backoff", 

941 "max_retry_delay", 

942 "start_date", 

943 "end_date", 

944 "depends_on_past", 

945 "wait_for_downstream", 

946 "priority_weight", 

947 "execution_timeout", 

948 "has_on_execute_callback", 

949 "has_on_failure_callback", 

950 "has_on_success_callback", 

951 "has_on_retry_callback", 

952 "has_on_skipped_callback", 

953 "do_xcom_push", 

954 "multiple_outputs", 

955 "allow_nested_operators", 

956 "executor", 

957 } 

958 

959 # If True, the Rendered Template fields will be overwritten in DB after execution 

960 # This is useful for Taskflow decorators that modify the template fields during execution like 

961 # @task.bash decorator. 

962 overwrite_rtif_after_execution: bool = False 

963 

964 # If True then the class constructor was called 

965 __instantiated: bool = False 

966 # List of args as passed to `__init__()`, after apply_defaults() has updated them. Used to "recreate" the task

967 # when mapping 

968 # Set via the metaclass 

969 __init_kwargs: dict[str, Any] = field(init=False) 

970 

971 # Set to True before calling execute method 

972 _lock_for_execution: bool = False 

973 

974 # Set to True for an operator instantiated by a mapped operator. 

975 __from_mapped: bool = False 

976 

977 start_trigger_args: StartTriggerArgs | None = None 

978 start_from_trigger: bool = False 

979 

980 # base list which includes all the attrs that don't need deep copy. 

981 _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = ( 

982 "user_defined_macros", 

983 "user_defined_filters", 

984 "params", 

985 ) 

986 

987 # each operator should override this class attr for shallow copy attrs. 

988 shallow_copy_attrs: Sequence[str] = () 

989 

990 def __setattr__(self: BaseOperator, key: str, value: Any): 

991 if converter := getattr(self, f"_convert_{key}", None): 

992 value = converter(value) 

993 super().__setattr__(key, value) 

994 if self.__from_mapped or self._lock_for_execution: 

995 return # Skip any custom behavior for validation and during execute. 

996 if key in self.__init_kwargs: 

997 self.__init_kwargs[key] = value 

998 if self.__instantiated and key in self.template_fields: 

999 # Resolve upstreams set by assigning an XComArg after initializing 

1000 # an operator, example: 

1001 # op = BashOperator() 

1002 # op.bash_command = "sleep 1" 

1003 self._set_xcomargs_dependency(key, value) 

1004 

1005 def __init__( 

1006 self, 

1007 *, 

1008 task_id: str, 

1009 owner: str = DEFAULT_OWNER, 

1010 email: str | Sequence[str] | None = None, 

1011 email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY, 

1012 email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE, 

1013 retries: int | None = DEFAULT_RETRIES, 

1014 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, 

1015 retry_exponential_backoff: float = 0, 

1016 max_retry_delay: timedelta | float | None = None, 

1017 start_date: datetime | None = None, 

1018 end_date: datetime | None = None, 

1019 depends_on_past: bool = False, 

1020 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, 

1021 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, 

1022 wait_for_downstream: bool = False, 

1023 dag: DAG | None = None, 

1024 params: collections.abc.MutableMapping[str, Any] | None = None, 

1025 default_args: dict | None = None, 

1026 priority_weight: int = DEFAULT_PRIORITY_WEIGHT, 

1027 weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE, 

1028 queue: str = DEFAULT_QUEUE, 

1029 pool: str | None = None, 

1030 pool_slots: int = DEFAULT_POOL_SLOTS, 

1031 sla: timedelta | None = None, 

1032 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, 

1033 on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1034 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, 

1035 on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1036 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, 

1037 on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None, 

1038 pre_execute: TaskPreExecuteHook | None = None, 

1039 post_execute: TaskPostExecuteHook | None = None, 

1040 trigger_rule: str = DEFAULT_TRIGGER_RULE, 

1041 resources: dict[str, Any] | None = None, 

1042 run_as_user: str | None = None, 

1043 map_index_template: str | None = None, 

1044 max_active_tis_per_dag: int | None = None, 

1045 max_active_tis_per_dagrun: int | None = None, 

1046 executor: str | None = None, 

1047 executor_config: dict | None = None, 

1048 do_xcom_push: bool = True, 

1049 multiple_outputs: bool = False, 

1050 inlets: Any | None = None, 

1051 outlets: Any | None = None, 

1052 task_group: TaskGroup | None = None, 

1053 doc: str | None = None, 

1054 doc_md: str | None = None, 

1055 doc_json: str | None = None, 

1056 doc_yaml: str | None = None, 

1057 doc_rst: str | None = None, 

1058 task_display_name: str | None = None, 

1059 logger_name: str | None = None, 

1060 allow_nested_operators: bool = True, 

1061 render_template_as_native_obj: bool | None = None, 

1062 **kwargs: Any, 

1063 ): 

1064 # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any 

1065 

1066 # Only apply task_group prefix if this operator was not created from a mapped operator 

1067 # Mapped operators already have the prefix applied during their creation 

1068 if task_group and not self.__from_mapped: 

1069 self.task_id = task_group.child_id(task_id) 

1070 task_group.add(self) 

1071 else: 

1072 self.task_id = task_id 

1073 

1074 super().__init__() 

1075 self.task_group = task_group 

1076 

1077 kwargs.pop("_airflow_mapped_validation_only", None) 

1078 if kwargs: 

1079 raise TypeError( 

1080 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). " 

1081 f"Invalid arguments were:\n**kwargs: {redact(kwargs)}", 

1082 ) 

1083 validate_key(self.task_id) 

1084 

1085 self.owner = owner 

1086 self.email = email 

1087 self.email_on_retry = email_on_retry 

1088 self.email_on_failure = email_on_failure 

1089 

1090 if email is not None: 

1091 warnings.warn( 

1092 "Setting email on a task is deprecated; please migrate to SmtpNotifier.", 

1093 RemovedInAirflow4Warning, 

1094 stacklevel=2, 

1095 ) 

1096 if email and email_on_retry is not None: 

1097 warnings.warn( 

1098 "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.", 

1099 RemovedInAirflow4Warning, 

1100 stacklevel=2, 

1101 ) 

1102 if email and email_on_failure is not None: 

1103 warnings.warn( 

1104 "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.", 

1105 RemovedInAirflow4Warning, 

1106 stacklevel=2, 

1107 ) 

1108 

1109 if execution_timeout is not None and not isinstance(execution_timeout, timedelta): 

1110 raise ValueError( 

1111 f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}" 

1112 ) 

1113 self.execution_timeout = execution_timeout 

1114 

1115 self.on_execute_callback = _collect_from_input(on_execute_callback) 

1116 self.on_failure_callback = _collect_from_input(on_failure_callback) 

1117 self.on_success_callback = _collect_from_input(on_success_callback) 

1118 self.on_retry_callback = _collect_from_input(on_retry_callback) 

1119 self.on_skipped_callback = _collect_from_input(on_skipped_callback) 

1120 self._pre_execute_hook = pre_execute 

1121 self._post_execute_hook = post_execute 

1122 

1123 self.start_date = timezone.convert_to_utc(start_date) 

1124 self.end_date = timezone.convert_to_utc(end_date) 

1125 self.executor = executor 

1126 self.executor_config = executor_config or {} 

1127 self.run_as_user = run_as_user 

1128 # TODO: 

1129 # self.retries = parse_retries(retries) 

1130 self.retries = retries 

1131 self.queue = queue 

1132 self.pool = DEFAULT_POOL_NAME if pool is None else pool 

1133 self.pool_slots = pool_slots 

1134 if self.pool_slots < 1: 

1135 dag_str = f" in dag {dag.dag_id}" if dag else "" 

1136 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1") 

1137 if sla is not None: 

1138 warnings.warn( 

1139 "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1", 

1140 stacklevel=2, 

1141 ) 

1142 

1143 try: 

1144 TriggerRule(trigger_rule) 

1145 except ValueError: 

1146 raise ValueError( 

1147 f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]}," 

1148 f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'." 

1149 ) 

1150 

1151 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule) 

1152 

1153 self.depends_on_past: bool = depends_on_past 

1154 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past 

1155 self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping 

1156 self.wait_for_downstream: bool = wait_for_downstream 

1157 if wait_for_downstream: 

1158 self.depends_on_past = True 

1159 

1160 # Converted by setattr 

1161 self.retry_delay = retry_delay # type: ignore[assignment] 

1162 self.retry_exponential_backoff = retry_exponential_backoff 

1163 if max_retry_delay is not None: 

1164 self.max_retry_delay = max_retry_delay 

1165 

1166 self.resources = resources 

1167 

1168 self.params = ParamsDict(params) 

1169 

1170 self.priority_weight = priority_weight 

1171 self.weight_rule = weight_rule 

1172 

1173 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag 

1174 self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun 

1175 self.do_xcom_push: bool = do_xcom_push 

1176 self.map_index_template: str | None = map_index_template 

1177 self.multiple_outputs: bool = multiple_outputs 

1178 

1179 self.doc_md = doc_md 

1180 self.doc_json = doc_json 

1181 self.doc_yaml = doc_yaml 

1182 self.doc_rst = doc_rst 

1183 self.doc = doc 

1184 

1185 self._task_display_name = task_display_name 

1186 

1187 self.allow_nested_operators = allow_nested_operators 

1188 

1189 self.render_template_as_native_obj = render_template_as_native_obj 

1190 

1191 self._logger_name = logger_name 

1192 

1193 # Lineage 

1194 self.inlets = _collect_from_input(inlets) 

1195 self.outlets = _collect_from_input(outlets) 

1196 

1197 if isinstance(self.template_fields, str): 

1198 warnings.warn( 

1199 f"The `template_fields` value for {self.task_type} is a string " 

1200 "but should be a list or tuple of string. Wrapping it in a list for execution. " 

1201 f"Please update {self.task_type} accordingly.", 

1202 UserWarning, 

1203 stacklevel=2, 

1204 ) 

1205 self.template_fields = [self.template_fields] 

1206 

1207 self.is_setup = False 

1208 self.is_teardown = False 

1209 

1210 if SetupTeardownContext.active: 

1211 SetupTeardownContext.update_context_map(self) 

1212 

1213 # We set self.dag right at the end as `_convert_dag` calls `dag.add_task` for us, and we need all the 

1214 # other properties to be set at that point 

1215 if dag is not None: 

1216 self.dag = dag 

1217 

1218 validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES) 

1219 

1220 # Ensure priority_weight is within the valid range 

1221 self.priority_weight = db_safe_priority(self.priority_weight) 

1222 

1223 def __eq__(self, other): 

1224 if type(self) is type(other): 

1225 # Use getattr() instead of __dict__ as __dict__ doesn't return 

1226 # correct values for properties. 

1227 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) 

1228 return False 

1229 

1230 def __ne__(self, other): 

1231 return not self == other 

1232 

1233 def __hash__(self): 

1234 hash_components = [type(self)] 

1235 for component in self._comps: 

1236 val = getattr(self, component, None) 

1237 try: 

1238 hash(val) 

1239 hash_components.append(val) 

1240 except TypeError: 

1241 hash_components.append(repr(val)) 

1242 return hash(tuple(hash_components)) 

1243 

1244 # /Composing Operators --------------------------------------------- 

1245 

1246 def __gt__(self, other): 

1247 """ 

1248 Return [Operator] > [Outlet]. 

1249 

1250 If other is an attr annotated object it is set as an outlet of this Operator. 

1251 """ 

1252 if not isinstance(other, Iterable): 

1253 other = [other] 

1254 

1255 for obj in other: 

1256 if not attrs.has(obj): 

1257 raise TypeError(f"Left hand side ({obj}) is not an outlet") 

1258 self.add_outlets(other) 

1259 

1260 return self 

1261 

1262 def __lt__(self, other): 

1263 """ 

1264 Return [Inlet] > [Operator] or [Operator] < [Inlet]. 

1265 

1266 If other is an attr annotated object it is set as an inlet to this operator. 

1267 """ 

1268 if not isinstance(other, Iterable): 

1269 other = [other] 

1270 

1271 for obj in other: 

1272 if not attrs.has(obj): 

1273 raise TypeError(f"{obj} cannot be an inlet") 

1274 self.add_inlets(other) 

1275 

1276 return self 

1277 
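A hedged sketch of the inlet/outlet shorthand implemented by ``__gt__``/``__lt__`` above, assuming ``Asset`` (an attrs-decorated SDK class) and an existing ``task`` instance:

    from airflow.sdk import Asset

    task > Asset("s3://bucket/processed/")   # equivalent to task.add_outlets([Asset(...)])
    task < Asset("s3://bucket/raw/")         # equivalent to task.add_inlets([Asset(...)])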

1278 def __deepcopy__(self, memo: dict[int, Any]): 

1279 # Hack sorting double chained task lists by task_id to avoid hitting 

1280 # max_depth on deepcopy operations. 

1281 sys.setrecursionlimit(5000) # TODO fix this in a better way 

1282 

1283 cls = self.__class__ 

1284 result = cls.__new__(cls) 

1285 memo[id(self)] = result 

1286 

1287 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs 

1288 

1289 for k, v_org in self.__dict__.items(): 

1290 if k not in shallow_copy: 

1291 v = copy.deepcopy(v_org, memo) 

1292 else: 

1293 v = copy.copy(v_org) 

1294 

1295 # Bypass any setters, and set it on the object directly. This works since we are cloning ourselves so

1296 # we know the type is already fine 

1297 result.__dict__[k] = v 

1298 return result 

1299 

1300 def __getstate__(self): 

1301 state = dict(self.__dict__) 

1302 if "_log" in state: 

1303 del state["_log"] 

1304 

1305 return state 

1306 

1307 def __setstate__(self, state): 

1308 self.__dict__ = state 

1309 

1310 def add_inlets(self, inlets: Iterable[Any]): 

1311 """Set inlets to this operator.""" 

1312 self.inlets.extend(inlets) 

1313 

1314 def add_outlets(self, outlets: Iterable[Any]): 

1315 """Define the outlets of this operator.""" 

1316 self.outlets.extend(outlets) 

1317 

1318 def get_dag(self) -> DAG | None: 

1319 return self._dag 

1320 

1321 @property 

1322 def dag(self) -> DAG: 

1323 """Returns the Operator's Dag if set, otherwise raises an error.""" 

1324 if dag := self._dag: 

1325 return dag 

1326 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet") 

1327 

1328 @dag.setter 

1329 def dag(self, dag: DAG | None) -> None: 

1330 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok.""" 

1331 self._dag = dag 

1332 

1333 def _convert__dag(self, dag: DAG | None) -> DAG | None: 

1334 # Called automatically by __setattr__ method 

1335 from airflow.sdk.definitions.dag import DAG 

1336 

1337 if dag is None: 

1338 return dag 

1339 

1340 if not isinstance(dag, DAG): 

1341 raise TypeError(f"Expected dag; received {dag.__class__.__name__}") 

1342 if self._dag is not None and self._dag is not dag: 

1343 raise ValueError(f"The dag assigned to {self} can not be changed.") 

1344 

1345 if self.__from_mapped: 

1346 pass # Don't add to dag -- the mapped task takes the place. 

1347 elif dag.task_dict.get(self.task_id) is not self: 

1348 dag.add_task(self) 

1349 return dag 

1350 

1351 @staticmethod 

1352 def _convert_retries(retries: Any) -> int | None: 

1353 if retries is None: 

1354 return 0 

1355 if type(retries) == int: # noqa: E721 

1356 return retries 

1357 try: 

1358 parsed_retries = int(retries) 

1359 except (TypeError, ValueError): 

1360 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}") 

1361 return parsed_retries 

1362 

1363 @staticmethod 

1364 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None: 

1365 if value is None or isinstance(value, timedelta): 

1366 return value 

1367 return timedelta(seconds=value) 

1368 

1369 _convert_retry_delay = _convert_timedelta 

1370 _convert_max_retry_delay = _convert_timedelta 

1371 

1372 @staticmethod 

1373 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None: 

1374 if resources is None: 

1375 return None 

1376 

1377 from airflow.sdk.definitions.operator_resources import Resources 

1378 

1379 if isinstance(resources, Resources): 

1380 return resources 

1381 

1382 return Resources(**resources) 

1383 
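The ``_convert_*`` static methods above normalize loosely typed constructor arguments (string retries, numeric delays, resource dicts). The sketch below calls them directly purely for illustration; they are private helpers, invoked for us by the operator's attribute machinery rather than by DAG code::

    from datetime import timedelta

    # Illustration only: exercising the private converters directly.
    assert BaseOperator._convert_retries(None) == 0
    assert BaseOperator._convert_retries("3") == 3
    assert BaseOperator._convert_timedelta(300) == timedelta(seconds=300)
    assert BaseOperator._convert_timedelta(None) is None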

1384 def _convert_is_setup(self, value: bool) -> bool: 

1385 """ 

1386 Setter for is_setup property. 

1387 

1388 :meta private: 

1389 """ 

1390 if self.is_teardown and value: 

1391 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.") 

1392 return value 

1393 

1394 def _convert_is_teardown(self, value: bool) -> bool: 

1395 if self.is_setup and value: 

1396 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.") 

1397 return value 

1398 

1399 @property 

1400 def task_display_name(self) -> str: 

1401 return self._task_display_name or self.task_id 

1402 

1403 def has_dag(self): 

1404 """Return True if the Operator has been assigned to a Dag.""" 

1405 return self._dag is not None 

1406 

1407 def _set_xcomargs_dependencies(self) -> None: 

1408 from airflow.sdk.definitions.xcom_arg import XComArg 

1409 

1410 for f in self.template_fields: 

1411 arg = getattr(self, f, NOTSET) 

1412 if arg is not NOTSET: 

1413 XComArg.apply_upstream_relationship(self, arg) 

1414 

1415 def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None: 

1416 """ 

1417 Resolve upstream dependencies of a task. 

1418 

1419 In this way, passing an ``XComArg`` as the value of a template field 

1420 will result in creating an upstream relation between the two tasks. 

1421 

1422 **Example**: :: 

1423 

1424 with DAG(...): 

1425 generate_content = GenerateContentOperator(task_id="generate_content") 

1426 send_email = EmailOperator(..., html_content=generate_content.output) 

1427 

1428 # This is equivalent to 

1429 with DAG(...): 

1430 generate_content = GenerateContentOperator(task_id="generate_content") 

1431 send_email = EmailOperator( 

1432 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}" 

1433 ) 

1434 generate_content >> send_email 

1435 

1436 """ 

1437 from airflow.sdk.definitions.xcom_arg import XComArg 

1438 

1439 if field not in self.template_fields: 

1440 return 

1441 XComArg.apply_upstream_relationship(self, newvalue) 

1442 

1443 def on_kill(self) -> None: 

1444 """ 

1445 Override this method to clean up subprocesses when a task instance gets killed. 

1446 

1447 Any use of the threading, subprocess or multiprocessing module within an 

1448 operator needs to be cleaned up, or it will leave ghost processes behind. 

1449 """ 

1450 

1451 def __repr__(self): 

1452 return f"<Task({self.task_type}): {self.task_id}>" 

1453 

1454 @property 

1455 def operator_class(self) -> type[BaseOperator]: # type: ignore[override] 

1456 return self.__class__ 

1457 

1458 @property 

1459 def task_type(self) -> str: 

1460 """@property: type of the task.""" 

1461 return self.__class__.__name__ 

1462 

1463 @property 

1464 def operator_name(self) -> str: 

1465 """@property: use a more friendly display name for the operator, if set.""" 

1466 try: 

1467 return self.custom_operator_name # type: ignore 

1468 except AttributeError: 

1469 return self.task_type 

1470 

1471 @property 

1472 def roots(self) -> list[BaseOperator]: 

1473 """Required by DAGNode.""" 

1474 return [self] 

1475 

1476 @property 

1477 def leaves(self) -> list[BaseOperator]: 

1478 """Required by DAGNode.""" 

1479 return [self] 

1480 

1481 @property 

1482 def output(self) -> XComArg: 

1483 """Returns reference to XCom pushed by current operator.""" 

1484 from airflow.sdk.definitions.xcom_arg import XComArg 

1485 

1486 return XComArg(operator=self) 

1487 

1488 @classmethod 

1489 def get_serialized_fields(cls): 

1490 """Stringified Dags and operators contain exactly these fields.""" 

1491 if not cls.__serialized_fields: 

1492 from airflow.sdk.definitions._internal.contextmanager import DagContext 

1493 

1494 # make sure the following "fake" task is not added to current active 

1495 # dag in context, otherwise, it will result in 

1496 # `RuntimeError: dictionary changed size during iteration` 

1497 # Exception in SerializedDAG.serialize_dag() call. 

1498 DagContext.push(None) 

1499 cls.__serialized_fields = frozenset( 

1500 vars(BaseOperator(task_id="test")).keys() 

1501 - { 

1502 "upstream_task_ids", 

1503 "default_args", 

1504 "dag", 

1505 "_dag", 

1506 "label", 

1507 "_BaseOperator__instantiated", 

1508 "_BaseOperator__init_kwargs", 

1509 "_BaseOperator__from_mapped", 

1510 "on_failure_fail_dagrun", 

1511 "task_group", 

1512 "_task_type", 

1513 "operator_extra_links", 

1514 "on_execute_callback", 

1515 "on_failure_callback", 

1516 "on_success_callback", 

1517 "on_retry_callback", 

1518 "on_skipped_callback", 

1519 } 

1520 | { # Class level defaults, or `@property` need to be added to this list 

1521 "start_date", 

1522 "end_date", 

1523 "task_type", 

1524 "ui_color", 

1525 "ui_fgcolor", 

1526 "template_ext", 

1527 "template_fields", 

1528 "template_fields_renderers", 

1529 "params", 

1530 "is_setup", 

1531 "is_teardown", 

1532 "on_failure_fail_dagrun", 

1533 "map_index_template", 

1534 "start_trigger_args", 

1535 "_needs_expansion", 

1536 "start_from_trigger", 

1537 "max_retry_delay", 

1538 "has_on_execute_callback", 

1539 "has_on_failure_callback", 

1540 "has_on_success_callback", 

1541 "has_on_retry_callback", 

1542 "has_on_skipped_callback", 

1543 } 

1544 ) 

1545 DagContext.pop() 

1546 

1547 return cls.__serialized_fields 

1548 

1549 def prepare_for_execution(self) -> Self: 

1550 """Lock task for execution to disable custom action in ``__setattr__`` and return a copy.""" 

1551 other = copy.copy(self) 

1552 other._lock_for_execution = True 

1553 return other 

1554 

1555 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: 

1556 """Serialize; required by DAGNode.""" 

1557 from airflow.sdk.api.datamodels._generated import DagAttributeTypes 

1558 

1559 return DagAttributeTypes.OP, self.task_id 

1560 

1561 def unmap(self, resolve: None | Mapping[str, Any]) -> Self: 

1562 """ 

1563 Get the "normal" operator from the current operator. 

1564 

1565 Since a BaseOperator is not mapped to begin with, this simply returns 

1566 the original operator. 

1567 

1568 :meta private: 

1569 """ 

1570 return self 

1571 

1572 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None: 

1573 """ 

1574 Get the start_trigger_args value of the current abstract operator. 

1575 

1576 Since a BaseOperator is not mapped to begin with, this simply returns 

1577 the original value of start_trigger_args. 

1578 

1579 :meta private: 

1580 """ 

1581 return self.start_trigger_args 

1582 

1583 def render_template_fields( 

1584 self, 

1585 context: Context, 

1586 jinja_env: jinja2.Environment | None = None, 

1587 ) -> None: 

1588 """ 

1589 Template all attributes listed in *self.template_fields*. 

1590 

1591 This mutates the attributes in-place and is irreversible. 

1592 

1593 :param context: Context dict with values to apply on content. 

1594 :param jinja_env: Jinja's environment to use for rendering. 

1595 """ 

1596 if not jinja_env: 

1597 jinja_env = self.get_template_env() 

1598 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) 

1599 
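``render_template_fields`` renders, in place, every attribute named in ``template_fields``. A hedged sketch of how a custom operator typically opts its arguments into templating; the operator and field names are illustrative only::

    # Hedged sketch: constructor arguments become Jinja-templated because they
    # are listed in template_fields. Names are illustrative.
    class ReportOperator(BaseOperator):
        template_fields = ("report_name", "destination")

        def __init__(self, *, report_name: str, destination: str, **kwargs):
            super().__init__(**kwargs)
            self.report_name = report_name
            self.destination = destination

        def execute(self, context):
            # By the time execute() runs, render_template_fields() has replaced
            # the Jinja expressions below with concrete values from the context.
            self.log.info("Writing %s to %s", self.report_name, self.destination)


    report = ReportOperator(
        task_id="report",
        report_name="sales_{{ ds }}",
        destination="s3://bucket/{{ ds }}/sales.csv",
    )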

1600 def pre_execute(self, context: Any): 

1601 """Execute right before self.execute() is called.""" 

1602 

1603 def execute(self, context: Context) -> Any: 

1604 """ 

1605 Derive when creating an operator. 

1606 

1607 The main method to execute the task. Context is the same dictionary used 

1608 when rendering jinja templates. 

1609 

1610 Refer to get_template_context for more context. 

1611 """ 

1612 raise NotImplementedError() 

1613 

1614 def post_execute(self, context: Any, result: Any = None): 

1615 """ 

1616 Execute right after self.execute() is called. 

1617 

1618 It is passed the execution context and any results returned by the operator. 

1619 """ 

1620 
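``pre_execute``, ``execute`` and ``post_execute`` form the per-attempt lifecycle: the task runner is expected to call them in that order, handing ``execute``'s return value to ``post_execute`` (and, by default, pushing it to XCom). A small hedged sketch of an operator overriding all three::

    # Hedged sketch of the lifecycle hooks; the runner calls pre_execute(),
    # then execute(), then post_execute() with execute()'s result.
    class AuditedOperator(BaseOperator):
        def pre_execute(self, context):
            self.log.info("about to run %s", self.task_id)

        def execute(self, context):
            return {"rows_written": 42}  # returned value becomes the task's XCom

        def post_execute(self, context, result=None):
            self.log.info("finished %s with result %s", self.task_id, result)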

1621 def defer( 

1622 self, 

1623 *, 

1624 trigger: BaseTrigger, 

1625 method_name: str, 

1626 kwargs: dict[str, Any] | None = None, 

1627 timeout: timedelta | int | float | None = None, 

1628 ) -> NoReturn: 

1629 """ 

1630 Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. 

1631 

1632 This is achieved by raising a special exception (TaskDeferred) 

1633 which is caught in the main _execute_task wrapper. Triggers can send execution back to the task or end 

1634 the task instance directly. If the trigger ends the task instance itself, ``method_name`` should 

1635 be None; otherwise, provide the name of the method that should be used when resuming execution in 

1636 the task. 

1637 """ 

1638 from airflow.sdk.exceptions import TaskDeferred 

1639 

1640 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) 

1641 

1642 def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): 

1643 """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed.""" 

1644 from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout 

1645 

1646 if next_kwargs is None: 

1647 next_kwargs = {} 

1648 # __fail__ is a special signal value for next_method that indicates 

1649 # this task was scheduled specifically to fail. 

1650 

1651 if next_method == TRIGGER_FAIL_REPR: 

1652 next_kwargs = next_kwargs or {} 

1653 traceback = next_kwargs.get("traceback") 

1654 if traceback is not None: 

1655 self.log.error("Trigger failed:\n%s", "\n".join(traceback)) 

1656 if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT: 

1657 raise TaskDeferralTimeout(error) 

1658 raise TaskDeferralError(error) 

1659 # Grab the callable off the Operator/Task and add in any kwargs 

1660 execute_callable = getattr(self, next_method) 

1661 return execute_callable(context, **next_kwargs) 

1662 
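``defer`` raises ``TaskDeferred`` and, once the trigger fires, the task runner re-enters the operator through ``resume_execution``, which dispatches to the ``method_name`` captured at deferral time. A hedged sketch of the usual round trip; the ``TimeDeltaTrigger`` import path and the event payload shape are assumptions::

    from datetime import timedelta

    # Hedged sketch of defer/resume. TimeDeltaTrigger is one plausible trigger;
    # its import path here is an assumption.
    class WaitThenFinishOperator(BaseOperator):
        def execute(self, context):
            from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger

            # Suspend this task; when the trigger fires, resume_execution()
            # dispatches to execute_complete() below.
            self.defer(
                trigger=TimeDeltaTrigger(timedelta(minutes=5)),
                method_name="execute_complete",
            )

        def execute_complete(self, context, event=None):
            # Called (via resume_execution) with the trigger's event payload.
            return event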

1663 def dry_run(self) -> None: 

1664 """Perform dry run for the operator - just render template fields.""" 

1665 self.log.info("Dry run") 

1666 for f in self.template_fields: 

1667 try: 

1668 content = getattr(self, f) 

1669 except AttributeError: 

1670 raise AttributeError( 

1671 f"{f!r} is configured as a template field " 

1672 f"but {self.task_type} does not have this attribute." 

1673 ) 

1674 

1675 if content and isinstance(content, str): 

1676 self.log.info("Rendering template for %s", f) 

1677 self.log.info(content) 

1678 

1679 @property 

1680 def has_on_execute_callback(self) -> bool: 

1681 """Return True if the task has execute callbacks.""" 

1682 return bool(self.on_execute_callback) 

1683 

1684 @property 

1685 def has_on_failure_callback(self) -> bool: 

1686 """Return True if the task has failure callbacks.""" 

1687 return bool(self.on_failure_callback) 

1688 

1689 @property 

1690 def has_on_success_callback(self) -> bool: 

1691 """Return True if the task has success callbacks.""" 

1692 return bool(self.on_success_callback) 

1693 

1694 @property 

1695 def has_on_retry_callback(self) -> bool: 

1696 """Return True if the task has retry callbacks.""" 

1697 return bool(self.on_retry_callback) 

1698 

1699 @property 

1700 def has_on_skipped_callback(self) -> bool: 

1701 """Return True if the task has skipped callbacks.""" 

1702 return bool(self.on_skipped_callback) 

1703 

1704 

1705class BaseAsyncOperator(BaseOperator): 

1706 """ 

1707 Base class for async-capable operators. 

1708 

1709 As opposed to deferred operators which are executed on the triggerer, async operators are executed 

1710 on the worker. 

1711 """ 

1712 

1713 @property 

1714 def is_async(self) -> bool: 

1715 return True 

1716 

1717 async def aexecute(self, context): 

1718 """Async version of execute(). Subclasses should implement this.""" 

1719 raise NotImplementedError() 

1720 

1721 def execute(self, context): 

1722 """Run `aexecute()` inside an event loop.""" 

1723 with event_loop() as loop: 

1724 if self.execution_timeout: 

1725 return loop.run_until_complete( 

1726 asyncio.wait_for( 

1727 self.aexecute(context), 

1728 timeout=self.execution_timeout.total_seconds(), 

1729 ) 

1730 ) 

1731 return loop.run_until_complete(self.aexecute(context)) 

1732 

1733 
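A subclass of ``BaseAsyncOperator`` only implements ``aexecute``; the inherited ``execute`` drives it on an event loop on the worker and applies ``execution_timeout`` through ``asyncio.wait_for``. A hedged sketch follows; the ``aiohttp`` dependency and URL are illustrative assumptions::

    import aiohttp

    # Hedged sketch of an async-capable operator: only the coroutine is written,
    # BaseAsyncOperator.execute() runs it inside an event loop on the worker.
    class AsyncHttpCheckOperator(BaseAsyncOperator):
        def __init__(self, *, url: str, **kwargs):
            super().__init__(**kwargs)
            self.url = url

        async def aexecute(self, context):
            async with aiohttp.ClientSession() as session:
                async with session.get(self.url) as response:
                    response.raise_for_status()
                    return response.status


    check = AsyncHttpCheckOperator(task_id="check_api", url="https://example.com/health")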

1734def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: 

1735 r""" 

1736 Given a number of tasks, builds a dependency chain. 

1737 

1738 This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups, 

1739 or lists containing any mix of these types. If you want to chain between two 

1740 lists, you must ensure they have the same length. 

1741 

1742 Using classic operators/sensors: 

1743 

1744 .. code-block:: python 

1745 

1746 chain(t1, [t2, t3], [t4, t5], t6) 

1747 

1748 is equivalent to:: 

1749 

1750   / -> t2 -> t4 \ 

1751 t1               -> t6 

1752   \ -> t3 -> t5 / 

1753 

1754 .. code-block:: python 

1755 

1756 t1.set_downstream(t2) 

1757 t1.set_downstream(t3) 

1758 t2.set_downstream(t4) 

1759 t3.set_downstream(t5) 

1760 t4.set_downstream(t6) 

1761 t5.set_downstream(t6) 

1762 

1763 Using task-decorated functions aka XComArgs: 

1764 

1765 .. code-block:: python 

1766 

1767 chain(x1(), [x2(), x3()], [x4(), x5()], x6()) 

1768 

1769 is equivalent to:: 

1770 

1771   / -> x2 -> x4 \ 

1772 x1               -> x6 

1773   \ -> x3 -> x5 / 

1774 

1775 .. code-block:: python 

1776 

1777 x1 = x1() 

1778 x2 = x2() 

1779 x3 = x3() 

1780 x4 = x4() 

1781 x5 = x5() 

1782 x6 = x6() 

1783 x1.set_downstream(x2) 

1784 x1.set_downstream(x3) 

1785 x2.set_downstream(x4) 

1786 x3.set_downstream(x5) 

1787 x4.set_downstream(x6) 

1788 x5.set_downstream(x6) 

1789 

1790 Using TaskGroups: 

1791 

1792 .. code-block:: python 

1793 

1794 chain(t1, task_group1, task_group2, t2) 

1795 

1796 t1.set_downstream(task_group1) 

1797 task_group1.set_downstream(task_group2) 

1798 task_group2.set_downstream(t2) 

1799 

1800 

1801 It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups: 

1802 

1803 .. code-block:: python 

1804 

1805 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3()) 

1806 

1807 is equivalent to:: 

1808 

1809 / "branch one" -> x1 \ 

1810 t1 -> task_group1 -> x3 

1811 \ "branch two" -> x2 / 

1812 

1813 .. code-block:: python 

1814 

1815 x1 = x1() 

1816 x2 = x2() 

1817 x3 = x3() 

1818 label1 = Label("branch one") 

1819 label2 = Label("branch two") 

1820 t1.set_downstream(label1) 

1821 label1.set_downstream(x1) 

1822 t2.set_downstream(label2) 

1823 label2.set_downstream(x2) 

1824 x1.set_downstream(task_group1) 

1825 x2.set_downstream(task_group1) 

1826 task_group1.set_downstream(x3) 

1827 

1828 # or 

1829 

1830 x1 = x1() 

1831 x2 = x2() 

1832 x3 = x3() 

1833 t1.set_downstream(x1, edge_modifier=Label("branch one")) 

1834 t1.set_downstream(x2, edge_modifier=Label("branch two")) 

1835 x1.set_downstream(task_group1) 

1836 x2.set_downstream(task_group1) 

1837 task_group1.set_downstream(x3) 

1838 

1839 

1840 :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies 

1841 """ 

1842 for up_task, down_task in zip(tasks, tasks[1:]): 

1843 if isinstance(up_task, DependencyMixin): 

1844 up_task.set_downstream(down_task) 

1845 continue 

1846 if isinstance(down_task, DependencyMixin): 

1847 down_task.set_upstream(up_task) 

1848 continue 

1849 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence): 

1850 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}") 

1851 up_task_list = up_task 

1852 down_task_list = down_task 

1853 if len(up_task_list) != len(down_task_list): 

1854 raise ValueError( 

1855 f"Chain not supported for different length Iterable. " 

1856 f"Got {len(up_task_list)} and {len(down_task_list)}." 

1857 ) 

1858 for up_t, down_t in zip(up_task_list, down_task_list): 

1859 up_t.set_downstream(down_t) 

1860 

1861 

1862def cross_downstream( 

1863 from_tasks: Sequence[DependencyMixin], 

1864 to_tasks: DependencyMixin | Sequence[DependencyMixin], 

1865): 

1866 r""" 

1867 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks. 

1868 

1869 Using classic operators/sensors: 

1870 

1871 .. code-block:: python 

1872 

1873 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6]) 

1874 

1875 is equivalent to:: 

1876 

1877 t1 ---> t4 

1878    \ / 

1879 t2 -X -> t5 

1880    / \ 

1881 t3 ---> t6 

1882 

1883 .. code-block:: python 

1884 

1885 t1.set_downstream(t4) 

1886 t1.set_downstream(t5) 

1887 t1.set_downstream(t6) 

1888 t2.set_downstream(t4) 

1889 t2.set_downstream(t5) 

1890 t2.set_downstream(t6) 

1891 t3.set_downstream(t4) 

1892 t3.set_downstream(t5) 

1893 t3.set_downstream(t6) 

1894 

1895 Using task-decorated functions aka XComArgs: 

1896 

1897 .. code-block:: python 

1898 

1899 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()]) 

1900 

1901 is equivalent to:: 

1902 

1903 x1 ---> x4 

1904    \ / 

1905 x2 -X -> x5 

1906    / \ 

1907 x3 ---> x6 

1908 

1909 .. code-block:: python 

1910 

1911 x1 = x1() 

1912 x2 = x2() 

1913 x3 = x3() 

1914 x4 = x4() 

1915 x5 = x5() 

1916 x6 = x6() 

1917 x1.set_downstream(x4) 

1918 x1.set_downstream(x5) 

1919 x1.set_downstream(x6) 

1920 x2.set_downstream(x4) 

1921 x2.set_downstream(x5) 

1922 x2.set_downstream(x6) 

1923 x3.set_downstream(x4) 

1924 x3.set_downstream(x5) 

1925 x3.set_downstream(x6) 

1926 

1927 It is also possible to mix between classic operator/sensor and XComArg tasks: 

1928 

1929 .. code-block:: python 

1930 

1931 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()]) 

1932 

1933 is equivalent to:: 

1934 

1935 t1 ---> x1 

1936    \ / 

1937 x2 -X -> t2 

1938    / \ 

1939 t3 ---> x3 

1940 

1941 .. code-block:: python 

1942 

1943 x1 = x1() 

1944 x2 = x2() 

1945 x3 = x3() 

1946 t1.set_downstream(x1) 

1947 t1.set_downstream(t2) 

1948 t1.set_downstream(x3) 

1949 x2.set_downstream(x1) 

1950 x2.set_downstream(t2) 

1951 x2.set_downstream(x3) 

1952 t3.set_downstream(x1) 

1953 t3.set_downstream(t2) 

1954 t3.set_downstream(x3) 

1955 

1956 :param from_tasks: List of tasks or XComArgs to start from. 

1957 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies. 

1958 """ 

1959 for task in from_tasks: 

1960 task.set_downstream(to_tasks) 

1961 

1962 

1963def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]): 

1964 """ 

1965 Simplify task dependency definition. 

1966 

1967 E.g.: suppose you want precedence like so:: 

1968 

1969     ╭─op2─╮ ╭─op4─╮ 

1970 op1─┤     ├─├─op5─┤─op7 

1971     ╰-op3─╯ ╰-op6─╯ 

1972 

1973 Then you can accomplish like so:: 

1974 

1975 chain_linear(op1, [op2, op3], [op4, op5, op6], op7) 

1976 

1977 :param elements: a list of operators / lists of operators 

1978 """ 

1979 if not elements: 

1980 raise ValueError("No tasks provided; nothing to do.") 

1981 prev_elem = None 

1982 deps_set = False 

1983 for curr_elem in elements: 

1984 if isinstance(curr_elem, EdgeModifier): 

1985 raise ValueError("Labels are not supported by chain_linear") 

1986 if prev_elem is not None: 

1987 for task in prev_elem: 

1988 task >> curr_elem 

1989 if not deps_set: 

1990 deps_set = True 

1991 prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem 

1992 if not deps_set: 

1993 raise ValueError("No dependencies were set. Did you forget to expand with `*`?")