1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import abc
21import asyncio
22import collections.abc
23import contextlib
24import copy
25import inspect
26import sys
27import warnings
28from asyncio import AbstractEventLoop
29from collections.abc import Callable, Collection, Generator, Iterable, Mapping, Sequence
30from contextvars import ContextVar
31from dataclasses import dataclass, field
32from datetime import datetime, timedelta
33from enum import Enum
34from functools import total_ordering, wraps
35from types import FunctionType
36from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast
37
38import attrs
39
40from airflow.sdk import TriggerRule, timezone
41from airflow.sdk._shared.secrets_masker import redact
42from airflow.sdk.definitions._internal.abstractoperator import (
43 DEFAULT_EMAIL_ON_FAILURE,
44 DEFAULT_EMAIL_ON_RETRY,
45 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
46 DEFAULT_OWNER,
47 DEFAULT_POOL_NAME,
48 DEFAULT_POOL_SLOTS,
49 DEFAULT_PRIORITY_WEIGHT,
50 DEFAULT_QUEUE,
51 DEFAULT_RETRIES,
52 DEFAULT_RETRY_DELAY,
53 DEFAULT_TASK_EXECUTION_TIMEOUT,
54 DEFAULT_TRIGGER_RULE,
55 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
56 DEFAULT_WEIGHT_RULE,
57 AbstractOperator,
58 DependencyMixin,
59 TaskStateChangeCallback,
60)
61from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack
62from airflow.sdk.definitions._internal.node import validate_key
63from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
64from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args
65from airflow.sdk.definitions.edges import EdgeModifier
66from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
67from airflow.sdk.definitions.param import ParamsDict
68from airflow.sdk.exceptions import RemovedInAirflow4Warning
69
70# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
71# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
72# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
73# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
74DB_SAFE_MINIMUM = -2147483648
75DB_SAFE_MAXIMUM = 2147483647
76
77
78def db_safe_priority(priority_weight: int) -> int:
79 """Convert priority weight to a safe value for the database."""
80 return max(DB_SAFE_MINIMUM, min(DB_SAFE_MAXIMUM, priority_weight))
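
# Illustrative behaviour of db_safe_priority (derived from the clamp above): out-of-range weights
# are pinned to the database-safe bounds, while in-range values pass through unchanged, e.g.
#   db_safe_priority(2**40)  -> DB_SAFE_MAXIMUM
#   db_safe_priority(-2**40) -> DB_SAFE_MINIMUM
#   db_safe_priority(5)      -> 5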
81
82
83C = TypeVar("C", bound=Callable)
84T = TypeVar("T", bound=FunctionType)
85
86if TYPE_CHECKING:
87 from types import ClassMethodDescriptorType
88
89 import jinja2
90 from typing_extensions import Self
91
92 from airflow.sdk.api.datamodels._generated import DagAttributeTypes
93 from airflow.sdk.bases.operatorlink import BaseOperatorLink
94 from airflow.sdk.definitions.context import Context
95 from airflow.sdk.definitions.dag import DAG
96 from airflow.sdk.definitions.operator_resources import Resources
97 from airflow.sdk.definitions.taskgroup import TaskGroup
98 from airflow.sdk.definitions.xcom_arg import XComArg
99 from airflow.task.priority_strategy import PriorityWeightStrategy
100 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
101
102 TaskPreExecuteHook = Callable[[Context], None]
103 TaskPostExecuteHook = Callable[[Context, Any], None]
104
105__all__ = [
106 "BaseAsyncOperator",
107 "BaseOperator",
108 "chain",
109 "chain_linear",
110 "cross_downstream",
111]
112
113
114class TriggerFailureReason(str, Enum):
115 """
116 Reasons for trigger failures.
117
118 Internal use only.
119
120 :meta private:
121 """
122
123 TRIGGER_TIMEOUT = "Trigger timeout"
124 TRIGGER_FAILURE = "Trigger failure"
125
126
127TRIGGER_FAIL_REPR = "__fail__"
128"""String value to represent trigger failure.
129
130Internal use only.
131
132:meta private:
133"""
134
135
136def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
137 if not dag:
138 return {}, ParamsDict()
139 dag_args = copy.copy(dag.default_args)
140 dag_params = copy.deepcopy(dag.params)
141 dag_params._fill_missing_param_source("dag")
142 if task_group:
143 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
144 raise TypeError("default_args must be a mapping")
145 dag_args.update(task_group.default_args)
146 return dag_args, dag_params
147
148
149def get_merged_defaults(
150 dag: DAG | None,
151 task_group: TaskGroup | None,
152 task_params: collections.abc.MutableMapping | None,
153 task_default_args: dict | None,
154) -> tuple[dict, ParamsDict]:
155 args, params = _get_parent_defaults(dag, task_group)
156 if task_params:
157 if not isinstance(task_params, collections.abc.Mapping):
158 raise TypeError(f"params must be a mapping, got {type(task_params)}")
159
160 task_params = ParamsDict(task_params)
161 task_params._fill_missing_param_source("task")
162 params.update(task_params)
163
164 if task_default_args:
165 if not isinstance(task_default_args, collections.abc.Mapping):
            raise TypeError(f"default_args must be a mapping, got {type(task_default_args)}")
167 args.update(task_default_args)
168 with contextlib.suppress(KeyError):
169 if params_from_default_args := ParamsDict(task_default_args["params"] or {}):
170 params_from_default_args._fill_missing_param_source("task")
171 params.update(params_from_default_args)
172
173 return args, params
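
# Precedence sketch for get_merged_defaults (illustrative, not executed here): with dag.params == {"x": 1}
# and a task created with params={"x": 2}, the returned params resolve "x" to 2 -- task-level params
# override dag/task-group defaults -- and any "params" mapping found inside default_args is merged in
# during the same task-level pass.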
174
175
176def parse_retries(retries: Any) -> int | None:
177 if retries is None:
178 return 0
179 if type(retries) == int: # noqa: E721
180 return retries
181 try:
182 parsed_retries = int(retries)
183 except (TypeError, ValueError):
184 raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}")
185 return parsed_retries
186
187
188def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta:
189 if isinstance(value, timedelta):
190 return value
191 return timedelta(seconds=value)
192
193
194def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
195 if resources is None:
196 return None
197 from airflow.sdk.definitions.operator_resources import Resources
198
199 return Resources(**resources)
200
201
202@contextlib.contextmanager
203def event_loop() -> Generator[AbstractEventLoop]:
204 new_event_loop = False
205 loop = None
206 try:
207 try:
208 loop = asyncio.get_event_loop()
209 if loop.is_closed():
210 raise RuntimeError
211 except RuntimeError:
212 loop = asyncio.new_event_loop()
213 asyncio.set_event_loop(loop)
214 new_event_loop = True
215 yield loop
216 finally:
217 if new_event_loop and loop is not None:
218 with contextlib.suppress(AttributeError):
219 loop.close()
220 asyncio.set_event_loop(None)
221
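# Typical usage of event_loop() (sketch): run a coroutine from synchronous operator code without
# leaking a loop; a new loop is only created -- and closed afterwards -- when no usable loop exists.
#
#     with event_loop() as loop:
#         result = loop.run_until_complete(some_coroutine())  # some_coroutine is hypothetical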
222
223class _PartialDescriptor:
224 """A descriptor that guards against ``.partial`` being called on Task objects."""
225
226 class_method: ClassMethodDescriptorType | None = None
227
228 def __get__(
229 self, obj: BaseOperator, cls: type[BaseOperator] | None = None
230 ) -> Callable[..., OperatorPartial]:
231 # Call this "partial" so it looks nicer in stack traces.
232 def partial(**kwargs):
233 raise TypeError("partial can only be called on Operator classes, not Tasks themselves")
234
235 if obj is not None:
236 return partial
237 return self.class_method.__get__(cls, cls)
238
239
240OPERATOR_DEFAULTS: dict[str, Any] = {
241 "allow_nested_operators": True,
242 "depends_on_past": False,
243 "email_on_failure": DEFAULT_EMAIL_ON_FAILURE,
244 "email_on_retry": DEFAULT_EMAIL_ON_RETRY,
245 "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
246 # "executor": DEFAULT_EXECUTOR,
247 "executor_config": {},
248 "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
249 "inlets": [],
250 "map_index_template": None,
251 "on_execute_callback": [],
252 "on_failure_callback": [],
253 "on_retry_callback": [],
254 "on_skipped_callback": [],
255 "on_success_callback": [],
256 "outlets": [],
257 "owner": DEFAULT_OWNER,
258 "pool_slots": DEFAULT_POOL_SLOTS,
259 "priority_weight": DEFAULT_PRIORITY_WEIGHT,
260 "queue": DEFAULT_QUEUE,
261 "retries": DEFAULT_RETRIES,
262 "retry_delay": DEFAULT_RETRY_DELAY,
263 "retry_exponential_backoff": 0,
264 "trigger_rule": DEFAULT_TRIGGER_RULE,
265 "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
266 "wait_for_downstream": False,
267 "weight_rule": DEFAULT_WEIGHT_RULE,
268}
269
270
271# This is what handles the actual mapping.
272
273if TYPE_CHECKING:
274
275 def partial(
276 operator_class: type[BaseOperator],
277 *,
278 task_id: str,
279 dag: DAG | None = None,
280 task_group: TaskGroup | None = None,
281 start_date: datetime = ...,
282 end_date: datetime = ...,
283 owner: str = ...,
284 email: None | str | Iterable[str] = ...,
285 params: collections.abc.MutableMapping | None = None,
286 resources: dict[str, Any] | None = ...,
287 trigger_rule: str = ...,
288 depends_on_past: bool = ...,
289 ignore_first_depends_on_past: bool = ...,
290 wait_for_past_depends_before_skipping: bool = ...,
291 wait_for_downstream: bool = ...,
292 retries: int | None = ...,
293 queue: str = ...,
294 pool: str = ...,
295 pool_slots: int = ...,
296 execution_timeout: timedelta | None = ...,
297 max_retry_delay: None | timedelta | float = ...,
298 retry_delay: timedelta | float = ...,
299 retry_exponential_backoff: float = ...,
300 priority_weight: int = ...,
301 weight_rule: str | PriorityWeightStrategy = ...,
302 sla: timedelta | None = ...,
303 map_index_template: str | None = ...,
304 max_active_tis_per_dag: int | None = ...,
305 max_active_tis_per_dagrun: int | None = ...,
306 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
307 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
308 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
309 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
310 on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
311 run_as_user: str | None = ...,
312 executor: str | None = ...,
313 executor_config: dict | None = ...,
314 inlets: Any | None = ...,
315 outlets: Any | None = ...,
316 doc: str | None = ...,
317 doc_md: str | None = ...,
318 doc_json: str | None = ...,
319 doc_yaml: str | None = ...,
320 doc_rst: str | None = ...,
321 task_display_name: str | None = ...,
322 logger_name: str | None = ...,
323 allow_nested_operators: bool = True,
324 **kwargs,
325 ) -> OperatorPartial: ...
326else:
327
328 def partial(
329 operator_class: type[BaseOperator],
330 *,
331 task_id: str,
332 dag: DAG | None = None,
333 task_group: TaskGroup | None = None,
334 params: collections.abc.MutableMapping | None = None,
335 **kwargs,
336 ):
337 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
338
339 validate_mapping_kwargs(operator_class, "partial", kwargs)
340
341 dag = dag or DagContext.get_current()
342 if dag:
343 task_group = task_group or TaskGroupContext.get_current(dag)
344 if task_group:
345 task_id = task_group.child_id(task_id)
346
347 # Merge Dag and task group level defaults into user-supplied values.
348 dag_default_args, partial_params = get_merged_defaults(
349 dag=dag,
350 task_group=task_group,
351 task_params=params,
352 task_default_args=kwargs.pop("default_args", None),
353 )
354
355 # Create partial_kwargs from args and kwargs
356 partial_kwargs: dict[str, Any] = {
357 "task_id": task_id,
358 "dag": dag,
359 "task_group": task_group,
360 **kwargs,
361 }
362
363 # Inject Dag-level default args into args provided to this function.
364 # Most of the default args will be retrieved during unmapping; here we
365 # only ensure base properties are correctly set for the scheduler.
366 partial_kwargs.update(
367 (k, v)
368 for k, v in dag_default_args.items()
369 if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names
370 )
371
372 # Fill fields not provided by the user with default values.
373 partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs)
374
375 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
376 if "task_concurrency" in kwargs: # Reject deprecated option.
377 raise TypeError("unexpected argument: task_concurrency")
378 if start_date := partial_kwargs.get("start_date", None):
379 partial_kwargs["start_date"] = timezone.convert_to_utc(start_date)
380 if end_date := partial_kwargs.get("end_date", None):
381 partial_kwargs["end_date"] = timezone.convert_to_utc(end_date)
382 if partial_kwargs["pool_slots"] < 1:
383 dag_str = ""
384 if dag:
385 dag_str = f" in dag {dag.dag_id}"
386 raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
387 if retries := partial_kwargs.get("retries"):
388 partial_kwargs["retries"] = BaseOperator._convert_retries(retries)
389 partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"])
390 partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay(
391 partial_kwargs.get("max_retry_delay", None)
392 )
393
        # Normalize the callback arguments into lists.
        for k in ("execute", "failure", "success", "retry", "skipped"):
            attr = f"on_{k}_callback"
            partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr))
396
397 return OperatorPartial(
398 operator_class=operator_class,
399 kwargs=partial_kwargs,
400 params=partial_params,
401 )
402
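# Usage sketch for ``partial`` (normally reached via ``BaseOperator.partial`` on an operator class,
# then expanded over mapped arguments); the operator and arguments below are illustrative:
#
#     BashOperator.partial(task_id="greet", retries=1).expand(
#         bash_command=["echo a", "echo b"],
#     )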
403
404class ExecutorSafeguard:
405 """
406 The ExecutorSafeguard decorator.
407
    Checks that the execute method of an operator is not called manually outside
    the TaskInstance, since we want to avoid bad mixing between decorated and
    classic operators.
411 """
412
413 test_mode: ClassVar[bool] = False
414 tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel")
415 sentinel_value: ClassVar[object] = object()
416
417 @classmethod
418 def decorator(cls, func):
419 @wraps(func)
420 def wrapper(self, *args, **kwargs):
421 sentinel_key = f"{self.__class__.__name__}__sentinel"
422 sentinel = kwargs.pop(sentinel_key, None)
423
424 with contextlib.ExitStack() as stack:
425 if sentinel is cls.sentinel_value:
426 token = cls.tracker.set(self)
427 sentinel = self
428 stack.callback(cls.tracker.reset, token)
429 else:
                    # No sentinel passed in directly; it may have been passed to a subclass's
                    # execute, which would have set the tracker.
431 sentinel = cls.tracker.get(None)
432
433 if not cls.test_mode and sentinel is not self:
434 message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!"
435 if not self.allow_nested_operators:
436 raise RuntimeError(message)
437 self.log.warning(message)
438
439 # Now that we've logged, set sentinel so that `super()` calls don't log again
440 token = cls.tracker.set(self)
441 stack.callback(cls.tracker.reset, token)
442
443 return func(self, *args, **kwargs)
444
445 return wrapper
446
447
448if "airflow.configuration" in sys.modules:
    # Don't try to import it if it's not already loaded
450 from airflow.sdk.configuration import conf
451
452 ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode")
453
454
455def _collect_from_input(value_or_values: None | C | Collection[C]) -> list[C]:
456 if not value_or_values:
457 return []
458 if isinstance(value_or_values, Collection):
459 return list(value_or_values)
460 return [value_or_values]
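
# Behaviour sketch for _collect_from_input (illustrative; the callback names are hypothetical):
#   _collect_from_input(None)          -> []
#   _collect_from_input(my_cb)         -> [my_cb]        (a single callable gets wrapped in a list)
#   _collect_from_input([cb_a, cb_b])  -> [cb_a, cb_b]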
461
462
463class BaseOperatorMeta(abc.ABCMeta):
464 """Metaclass of BaseOperator."""
465
466 @classmethod
467 def _apply_defaults(cls, func: T) -> T:
468 """
469 Look for an argument named "default_args", and fill the unspecified arguments from it.
470
        Since Python isn't always clear about which arguments are missing when
        calling a function, and since this can be quite confusing with multi-level
        inheritance and argument defaults, this decorator also alerts with
        specific information about the missing arguments.
475 """
476 # Cache inspect.signature for the wrapper closure to avoid calling it
        # at every decorated invocation. This is a separate sig_cache created
        # per decoration, i.e. each function decorated using apply_defaults will
        # have a different sig_cache.
480 sig_cache = inspect.signature(func)
481 non_variadic_params = {
482 name: param
483 for (name, param) in sig_cache.parameters.items()
484 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
485 }
486 non_optional_args = {
487 name
488 for name, param in non_variadic_params.items()
489 if param.default == param.empty and name != "task_id"
490 }
491
492 fixup_decorator_warning_stack(func)
493
494 @wraps(func)
495 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
496 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
497
498 if args:
499 raise TypeError("Use keyword arguments when initializing operators")
500
501 instantiated_from_mapped = kwargs.pop(
502 "_airflow_from_mapped",
503 getattr(self, "_BaseOperator__from_mapped", False),
504 )
505
506 dag: DAG | None = kwargs.get("dag")
507 if dag is None:
508 dag = DagContext.get_current()
509 if dag is not None:
510 kwargs["dag"] = dag
511
512 task_group: TaskGroup | None = kwargs.get("task_group")
513 if dag and not task_group:
514 task_group = TaskGroupContext.get_current(dag)
515 if task_group is not None:
516 kwargs["task_group"] = task_group
517
518 default_args, merged_params = get_merged_defaults(
519 dag=dag,
520 task_group=task_group,
521 task_params=kwargs.pop("params", None),
522 task_default_args=kwargs.pop("default_args", None),
523 )
524
525 for arg in sig_cache.parameters:
526 if arg not in kwargs and arg in default_args:
527 kwargs[arg] = default_args[arg]
528
529 missing_args = non_optional_args.difference(kwargs)
530 if len(missing_args) == 1:
531 raise TypeError(f"missing keyword argument {missing_args.pop()!r}")
532 if missing_args:
533 display = ", ".join(repr(a) for a in sorted(missing_args))
534 raise TypeError(f"missing keyword arguments {display}")
535
536 if merged_params:
537 kwargs["params"] = merged_params
538
539 hook = getattr(self, "_hook_apply_defaults", None)
540 if hook:
541 args, kwargs = hook(**kwargs, default_args=default_args)
542 default_args = kwargs.pop("default_args", {})
543
544 if not hasattr(self, "_BaseOperator__init_kwargs"):
545 object.__setattr__(self, "_BaseOperator__init_kwargs", {})
546 object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped)
547
548 result = func(self, **kwargs, default_args=default_args)
549
550 # Store the args passed to init -- we need them to support task.map serialization!
551 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
552
553 # Set upstream task defined by XComArgs passed to template fields of the operator.
554 # BUT: only do this _ONCE_, not once for each class in the hierarchy
555 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
556 self._set_xcomargs_dependencies()
557 # Mark instance as instantiated so that future attr setting updates xcomarg-based deps.
558 object.__setattr__(self, "_BaseOperator__instantiated", True)
559
560 return result
561
562 apply_defaults.__non_optional_args = non_optional_args # type: ignore
563 apply_defaults.__param_names = set(non_variadic_params) # type: ignore
564
565 return cast("T", apply_defaults)
566
567 def __new__(cls, name, bases, namespace, **kwargs):
568 execute_method = namespace.get("execute")
569 if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
570 namespace["execute"] = ExecutorSafeguard.decorator(execute_method)
571 new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
572 with contextlib.suppress(KeyError):
573 # Update the partial descriptor with the class method, so it calls the actual function
574 # (but let subclasses override it if they need to)
575 partial_desc = vars(new_cls)["partial"]
576 if isinstance(partial_desc, _PartialDescriptor):
577 partial_desc.class_method = classmethod(partial)
578
579 # We patch `__init__` only if the class defines it.
580 first_superclass = new_cls.mro()[1]
581 if new_cls.__init__ is not first_superclass.__init__:
582 new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
583
584 return new_cls
585
586
587# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the
588# correct type. This is a temporary solution until we find a more sophisticated method for argument
589# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not
590# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
591# version that supports `get_type_hints` effectively or find a better approach, we can replace this
592# manual type-checking method.
593BASEOPERATOR_ARGS_EXPECTED_TYPES = {
594 "task_id": str,
595 "email": (str, Sequence),
596 "email_on_retry": bool,
597 "email_on_failure": bool,
598 "retries": int,
599 "retry_exponential_backoff": (int, float),
600 "depends_on_past": bool,
601 "ignore_first_depends_on_past": bool,
602 "wait_for_past_depends_before_skipping": bool,
603 "wait_for_downstream": bool,
604 "priority_weight": int,
605 "queue": str,
606 "pool": str,
607 "pool_slots": int,
608 "trigger_rule": str,
609 "run_as_user": str,
610 "task_concurrency": int,
611 "map_index_template": str,
612 "max_active_tis_per_dag": int,
613 "max_active_tis_per_dagrun": int,
614 "executor": str,
615 "do_xcom_push": bool,
616 "multiple_outputs": bool,
617 "doc": str,
618 "doc_md": str,
619 "doc_json": str,
620 "doc_yaml": str,
621 "doc_rst": str,
622 "task_display_name": str,
623 "logger_name": str,
624 "allow_nested_operators": bool,
625 "start_date": datetime,
626 "end_date": datetime,
627}
628
629
630# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in
631# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it.
632#
633# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you
634# get nowhere leave your record here for the next poor soul and what problems you ran in to.
635#
636# @ashb, 2024/10/14
637# - "Can't combine custom __setattr__ with on_setattr hooks"
# - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses
639@total_ordering
640@dataclass(repr=False)
641class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
642 r"""
643 Abstract base class for all operators.
644
645 Since operators create objects that become nodes in the Dag, BaseOperator
646 contains many recursive methods for Dag crawling behavior. To derive from
647 this class, you are expected to override the constructor and the 'execute'
648 method.
649
650 Operators derived from this class should perform or trigger certain tasks
    synchronously (wait for completion). Examples of operators include one
    that runs a Pig job (PigOperator), a sensor that waits for a partition
    to land in Hive (HiveSensorOperator), or one that moves data from Hive
    to MySQL (Hive2MySqlOperator). Instances of these operators (tasks) target
    specific operations, running specific scripts, functions or data transfers.
657
658 This class is abstract and shouldn't be instantiated. Instantiating a
659 class derived from this one results in the creation of a task object,
660 which ultimately becomes a node in Dag objects. Task dependencies should
661 be set by using the set_upstream and/or set_downstream methods.
662
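    **Example**: a minimal custom operator (illustrative sketch; the class and field names
    are hypothetical)::

        class HelloOperator(BaseOperator):
            def __init__(self, *, name: str, **kwargs):
                super().__init__(**kwargs)
                self.name = name

            def execute(self, context):
                self.log.info("Hello %s", self.name)
                return self.name
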
663 :param task_id: a unique, meaningful id for the task
664 :param owner: the owner of the task. Using a meaningful description
665 (e.g. user/person/team/role name) to clarify ownership is recommended.
666 :param email: the 'to' email address(es) used in email alerts. This can be a
667 single email or multiple ones. Multiple addresses can be specified as a
668 comma or semicolon separated string or by passing a list of strings. (deprecated)
669 :param email_on_retry: Indicates whether email alerts should be sent when a
670 task is retried (deprecated)
671 :param email_on_failure: Indicates whether email alerts should be sent when
672 a task failed (deprecated)
673 :param retries: the number of retries that should be performed before
674 failing the task
675 :param retry_delay: delay between retries, can be set as ``timedelta`` or
676 ``float`` seconds, which will be converted into ``timedelta``,
677 the default is ``timedelta(seconds=300)``.
678 :param retry_exponential_backoff: multiplier for exponential backoff between retries.
679 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff
680 (delay doubles with each retry). For example, with retry_delay=4min and
681 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc.
682 :param max_retry_delay: maximum delay interval between retries, can be set as
683 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``.
684 :param start_date: The ``start_date`` for the task, determines
685 the ``logical_date`` for the first task instance. The best practice
686 is to have the start_date rounded
687 to your Dag's ``schedule_interval``. Daily jobs have their start_date
688 some day at 00:00:00, hourly jobs have their start_date at 00:00
689 of a specific hour. Note that Airflow simply looks at the latest
690 ``logical_date`` and adds the ``schedule_interval`` to determine
691 the next ``logical_date``. It is also very important
692 to note that different tasks' dependencies
693 need to line up in time. If task A depends on task B and their
694 start_date are offset in a way that their logical_date don't line
695 up, A's dependencies will never be met. If you are looking to delay
696 a task, for example running a daily task at 2AM, look into the
697 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
698 dynamic ``start_date`` and recommend using fixed ones. Read the
699 FAQ entry about start_date for more information.
700 :param end_date: if specified, the scheduler won't go beyond this date
701 :param depends_on_past: when set to true, task instances will run
702 sequentially and only if the previous instance has succeeded or has been skipped.
703 The task instance for the start_date is allowed to run.
    :param wait_for_past_depends_before_skipping: when set to true, if the task instance
        should be marked as skipped and depends_on_past is true, the ti will stay in the None
        state, waiting for the task of the previous run
707 :param wait_for_downstream: when set to true, an instance of task
708 X will wait for tasks immediately downstream of the previous instance
709 of task X to finish successfully or be skipped before it runs. This is useful if the
710 different instances of a task X alter the same asset, and this asset
711 is used by tasks downstream of task X. Note that depends_on_past
712 is forced to True wherever wait_for_downstream is used. Also note that
713 only tasks *immediately* downstream of the previous task instance are waited
714 for; the statuses of any tasks further downstream are ignored.
715 :param dag: a reference to the dag the task is attached to (if any)
    :param priority_weight: priority weight of this task against other tasks.
        This allows the executor to trigger higher priority tasks before
        others when things get backed up. Set priority_weight as a higher
        number for more important tasks.
        As not all database engines support 64-bit integers, values are capped to 32-bit integers.
        Valid range is from -2,147,483,648 to 2,147,483,647.
    :param weight_rule: weighting method used for the effective total
        priority weight of the task. Options are:
        ``{ downstream | upstream | absolute }``; the default is ``downstream``.
        When set to ``downstream``, the effective weight of the task is the
        aggregate sum of all downstream descendants. As a result, upstream
        tasks will have higher weight and will be scheduled more aggressively
        when using positive weight values. This is useful when you have
        multiple dag run instances and want all upstream tasks to
        complete for all runs before each dag can continue processing
        downstream tasks. When set to ``upstream``, the effective weight is the
        aggregate sum of all upstream ancestors. This is the opposite, where
        downstream tasks have higher weight and will be scheduled more
        aggressively when using positive weight values. This is useful when you
        have multiple dag run instances and prefer to have each dag complete
        before starting upstream tasks of other dags. When set to
        ``absolute``, the effective weight is the exact ``priority_weight``
        specified without additional weighting. You may want to do this when
        you know exactly what priority weight each task should have.
        Additionally, when set to ``absolute``, there is a bonus effect of
        significantly speeding up the task creation process for very large
        Dags. Options can be set as a string or using the constants defined in
        the static class ``airflow.utils.WeightRule``.
        Irrespective of the weight rule, resulting priority values are capped to 32-bit integers.
        |experimental|
        Since 2.9.0, Airflow allows defining a custom priority weight strategy
        by creating a subclass of
        ``airflow.task.priority_strategy.PriorityWeightStrategy``, registering
        it in a plugin, and then providing the class path or the class instance via
        the ``weight_rule`` parameter. The custom priority weight strategy will be
        used to calculate the effective total priority weight of the task instance.
    :param queue: which queue to target when running this job. Not
        all executors implement queue management; the CeleryExecutor, for example,
        does support targeting specific queues.
755 :param pool: the slot pool this task should run in, slot pools are a
756 way to limit concurrency for certain tasks
757 :param pool_slots: the number of pool slots this task should use (>= 1)
758 Values less than 1 are not allowed.
759 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a
760 new implementation in Airflow >=3.1.
    :param execution_timeout: max time allowed for the execution of
        this task instance; if it runs longer, the task will raise and fail.
    :param on_failure_callback: a function or list of functions to be called when a task instance
        of this task fails. A context dictionary is passed as a single
        parameter to this function. The context contains references to objects
        related to the task instance and is documented under the macros
        section of the API.
768 :param on_execute_callback: much like the ``on_failure_callback`` except
769 that it is executed right before the task is executed.
770 :param on_retry_callback: much like the ``on_failure_callback`` except
771 that it is executed when retries occur.
772 :param on_success_callback: much like the ``on_failure_callback`` except
773 that it is executed when the task succeeds.
    :param on_skipped_callback: much like the ``on_failure_callback`` except
        that it is executed when a skip occurs; this callback is called only if
        ``AirflowSkipException`` is raised. Explicitly, it is NOT called if the task never starts
        because a preceding branching decision in the Dag or a trigger rule causes execution to be
        skipped, so that the task execution is never scheduled.
779 :param pre_execute: a function to be called immediately before task
780 execution, receiving a context dictionary; raising an exception will
781 prevent the task from being executed.
782 :param post_execute: a function to be called immediately after task
783 execution, receiving a context dictionary and task result; raising an
784 exception will prevent the task from succeeding.
785 :param trigger_rule: defines the rule by which dependencies are applied
786 for the task to get triggered. Options are:
787 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
788 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
789 default is ``all_success``. Options can be set as string or
790 using the constants defined in the static class
791 ``airflow.utils.TriggerRule``
792 :param resources: A map of resource parameter names (the argument names of the
793 Resources constructor) to their values.
794 :param run_as_user: unix username to impersonate while running the task
795 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent
796 runs across logical_dates.
797 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent
798 task instances per Dag run.
799 :param executor: Which executor to target when running this task. NOT YET SUPPORTED
800 :param executor_config: Additional task-level configuration parameters that are
801 interpreted by a specific executor. Parameters are namespaced by the name of
802 executor.
803
804 **Example**: to run this task in a specific docker container through
805 the KubernetesExecutor ::
806
807 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}})
808
809 :param do_xcom_push: if True, an XCom is pushed containing the Operator's
810 result
811 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each
812 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom.
813 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not
814 using a TaskGroup as a context manager.
815 :param doc: Add documentation or notes to your Task objects that is visible in
816 Task Instance details View in the Webserver
817 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects
818 that is visible in Task Instance details View in the Webserver
819 :param doc_rst: Add documentation (in RST format) or notes to your Task objects
820 that is visible in Task Instance details View in the Webserver
821 :param doc_json: Add documentation (in JSON format) or notes to your Task objects
822 that is visible in Task Instance details View in the Webserver
823 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects
824 that is visible in Task Instance details View in the Webserver
825 :param task_display_name: The display name of the task which appears on the UI.
826 :param logger_name: Name of the logger used by the Operator to emit logs.
827 If set to `None` (default), the logger name will fall back to
828 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have
829 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger).
830 :param allow_nested_operators: if True, when an operator is executed within another one a warning message
831 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested
832 within another one). In future releases of Airflow this parameter will be removed and an exception
833 will always be thrown when operators are nested within each other (default is True).
834
835 **Example**: example of a bad operator mixin usage::
836
837 @task(provide_context=True)
838 def say_hello_world(**context):
839 hello_world_task = BashOperator(
840 task_id="hello_world_task",
841 bash_command="python -c \"print('Hello, world!')\"",
842 dag=dag,
843 )
                hello_world_task.execute(context)

    :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment``
846 to render templates as native Python types. If False, a Jinja
847 ``Environment`` is used to render templates as string values.
848 If None (default), inherits from the DAG setting.
849 """
850
851 task_id: str
852 owner: str = DEFAULT_OWNER
853 email: str | Sequence[str] | None = None
854 email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY
855 email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE
856 retries: int | None = DEFAULT_RETRIES
857 retry_delay: timedelta = DEFAULT_RETRY_DELAY
858 retry_exponential_backoff: float = 0
859 max_retry_delay: timedelta | float | None = None
860 start_date: datetime | None = None
861 end_date: datetime | None = None
862 depends_on_past: bool = False
863 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
864 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
865 wait_for_downstream: bool = False
866
867 # At execution_time this becomes a normal dict
868 params: ParamsDict | dict = field(default_factory=ParamsDict)
869 default_args: dict | None = None
870 priority_weight: int = DEFAULT_PRIORITY_WEIGHT
871 weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
872 queue: str = DEFAULT_QUEUE
873 pool: str = DEFAULT_POOL_NAME
874 pool_slots: int = DEFAULT_POOL_SLOTS
875 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
876 on_execute_callback: Sequence[TaskStateChangeCallback] = ()
877 on_failure_callback: Sequence[TaskStateChangeCallback] = ()
878 on_success_callback: Sequence[TaskStateChangeCallback] = ()
879 on_retry_callback: Sequence[TaskStateChangeCallback] = ()
880 on_skipped_callback: Sequence[TaskStateChangeCallback] = ()
881 _pre_execute_hook: TaskPreExecuteHook | None = None
882 _post_execute_hook: TaskPostExecuteHook | None = None
883 trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE
884 resources: dict[str, Any] | None = None
885 run_as_user: str | None = None
886 task_concurrency: int | None = None
887 map_index_template: str | None = None
888 max_active_tis_per_dag: int | None = None
889 max_active_tis_per_dagrun: int | None = None
890 executor: str | None = None
891 executor_config: dict | None = None
892 do_xcom_push: bool = True
893 multiple_outputs: bool = False
894 inlets: list[Any] = field(default_factory=list)
895 outlets: list[Any] = field(default_factory=list)
896 task_group: TaskGroup | None = None
897 doc: str | None = None
898 doc_md: str | None = None
899 doc_json: str | None = None
900 doc_yaml: str | None = None
901 doc_rst: str | None = None
902 _task_display_name: str | None = None
903 logger_name: str | None = None
904 allow_nested_operators: bool = True
905 render_template_as_native_obj: bool | None = None
906
907 is_setup: bool = False
908 is_teardown: bool = False
909
910 # TODO: Task-SDK: Make these ClassVar[]?
911 template_fields: Collection[str] = ()
912 template_ext: Sequence[str] = ()
913
914 template_fields_renderers: ClassVar[dict[str, str]] = {}
915
916 operator_extra_links: Collection[BaseOperatorLink] = ()
917
918 # Defines the color in the UI
919 ui_color: str = "#fff"
920 ui_fgcolor: str = "#000"
921
922 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore
923
924 _dag: DAG | None = field(init=False, default=None)
925
    # Make this optional so the type matches the one defined in LoggingMixin
927 _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False)
928 _logger_name: str | None = None
929
930 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
931 __serialized_fields: ClassVar[frozenset[str] | None] = None
932
933 _comps: ClassVar[set[str]] = {
934 "task_id",
935 "dag_id",
936 "owner",
937 "email",
938 "email_on_retry",
939 "retry_delay",
940 "retry_exponential_backoff",
941 "max_retry_delay",
942 "start_date",
943 "end_date",
944 "depends_on_past",
945 "wait_for_downstream",
946 "priority_weight",
947 "execution_timeout",
948 "has_on_execute_callback",
949 "has_on_failure_callback",
950 "has_on_success_callback",
951 "has_on_retry_callback",
952 "has_on_skipped_callback",
953 "do_xcom_push",
954 "multiple_outputs",
955 "allow_nested_operators",
956 "executor",
957 }
958
959 # If True, the Rendered Template fields will be overwritten in DB after execution
960 # This is useful for Taskflow decorators that modify the template fields during execution like
961 # @task.bash decorator.
962 overwrite_rtif_after_execution: bool = False
963
964 # If True then the class constructor was called
965 __instantiated: bool = False
    # List of kwargs as passed to `__init__()`, after apply_defaults() has updated them. Used to
    # "recreate" the task when mapping
968 # Set via the metaclass
969 __init_kwargs: dict[str, Any] = field(init=False)
970
971 # Set to True before calling execute method
972 _lock_for_execution: bool = False
973
974 # Set to True for an operator instantiated by a mapped operator.
975 __from_mapped: bool = False
976
977 start_trigger_args: StartTriggerArgs | None = None
978 start_from_trigger: bool = False
979
980 # base list which includes all the attrs that don't need deep copy.
981 _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = (
982 "user_defined_macros",
983 "user_defined_filters",
984 "params",
985 )
986
987 # each operator should override this class attr for shallow copy attrs.
988 shallow_copy_attrs: Sequence[str] = ()
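
    # Illustrative subclass override (hypothetical names): an operator holding a heavyweight,
    # non-copyable handle can exclude it from deep copies made of the task:
    #
    #     class MyOperator(BaseOperator):
    #         shallow_copy_attrs: Sequence[str] = ("client",)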
989
990 def __setattr__(self: BaseOperator, key: str, value: Any):
991 if converter := getattr(self, f"_convert_{key}", None):
992 value = converter(value)
993 super().__setattr__(key, value)
994 if self.__from_mapped or self._lock_for_execution:
995 return # Skip any custom behavior for validation and during execute.
996 if key in self.__init_kwargs:
997 self.__init_kwargs[key] = value
998 if self.__instantiated and key in self.template_fields:
999 # Resolve upstreams set by assigning an XComArg after initializing
1000 # an operator, example:
1001 # op = BashOperator()
1002 # op.bash_command = "sleep 1"
1003 self._set_xcomargs_dependency(key, value)
1004
1005 def __init__(
1006 self,
1007 *,
1008 task_id: str,
1009 owner: str = DEFAULT_OWNER,
1010 email: str | Sequence[str] | None = None,
1011 email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY,
1012 email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE,
1013 retries: int | None = DEFAULT_RETRIES,
1014 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
1015 retry_exponential_backoff: float = 0,
1016 max_retry_delay: timedelta | float | None = None,
1017 start_date: datetime | None = None,
1018 end_date: datetime | None = None,
1019 depends_on_past: bool = False,
1020 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
1021 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
1022 wait_for_downstream: bool = False,
1023 dag: DAG | None = None,
1024 params: collections.abc.MutableMapping[str, Any] | None = None,
1025 default_args: dict | None = None,
1026 priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
1027 weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
1028 queue: str = DEFAULT_QUEUE,
1029 pool: str | None = None,
1030 pool_slots: int = DEFAULT_POOL_SLOTS,
1031 sla: timedelta | None = None,
1032 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
1033 on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1034 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
1035 on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1036 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
1037 on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1038 pre_execute: TaskPreExecuteHook | None = None,
1039 post_execute: TaskPostExecuteHook | None = None,
1040 trigger_rule: str = DEFAULT_TRIGGER_RULE,
1041 resources: dict[str, Any] | None = None,
1042 run_as_user: str | None = None,
1043 map_index_template: str | None = None,
1044 max_active_tis_per_dag: int | None = None,
1045 max_active_tis_per_dagrun: int | None = None,
1046 executor: str | None = None,
1047 executor_config: dict | None = None,
1048 do_xcom_push: bool = True,
1049 multiple_outputs: bool = False,
1050 inlets: Any | None = None,
1051 outlets: Any | None = None,
1052 task_group: TaskGroup | None = None,
1053 doc: str | None = None,
1054 doc_md: str | None = None,
1055 doc_json: str | None = None,
1056 doc_yaml: str | None = None,
1057 doc_rst: str | None = None,
1058 task_display_name: str | None = None,
1059 logger_name: str | None = None,
1060 allow_nested_operators: bool = True,
1061 render_template_as_native_obj: bool | None = None,
1062 **kwargs: Any,
1063 ):
1064 # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any
1065
1066 # Only apply task_group prefix if this operator was not created from a mapped operator
1067 # Mapped operators already have the prefix applied during their creation
1068 if task_group and not self.__from_mapped:
1069 self.task_id = task_group.child_id(task_id)
1070 task_group.add(self)
1071 else:
1072 self.task_id = task_id
1073
1074 super().__init__()
1075 self.task_group = task_group
1076
1077 kwargs.pop("_airflow_mapped_validation_only", None)
1078 if kwargs:
1079 raise TypeError(
1080 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
1081 f"Invalid arguments were:\n**kwargs: {redact(kwargs)}",
1082 )
1083 validate_key(self.task_id)
1084
1085 self.owner = owner
1086 self.email = email
1087 self.email_on_retry = email_on_retry
1088 self.email_on_failure = email_on_failure
1089
1090 if email is not None:
1091 warnings.warn(
1092 "Setting email on a task is deprecated; please migrate to SmtpNotifier.",
1093 RemovedInAirflow4Warning,
1094 stacklevel=2,
1095 )
1096 if email and email_on_retry is not None:
1097 warnings.warn(
1098 "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.",
1099 RemovedInAirflow4Warning,
1100 stacklevel=2,
1101 )
1102 if email and email_on_failure is not None:
1103 warnings.warn(
1104 "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.",
1105 RemovedInAirflow4Warning,
1106 stacklevel=2,
1107 )
1108
1109 if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
1110 raise ValueError(
                f"execution_timeout must be a timedelta object but was passed as type: {type(execution_timeout)}"
1112 )
1113 self.execution_timeout = execution_timeout
1114
1115 self.on_execute_callback = _collect_from_input(on_execute_callback)
1116 self.on_failure_callback = _collect_from_input(on_failure_callback)
1117 self.on_success_callback = _collect_from_input(on_success_callback)
1118 self.on_retry_callback = _collect_from_input(on_retry_callback)
1119 self.on_skipped_callback = _collect_from_input(on_skipped_callback)
1120 self._pre_execute_hook = pre_execute
1121 self._post_execute_hook = post_execute
1122
1123 self.start_date = timezone.convert_to_utc(start_date)
1124 self.end_date = timezone.convert_to_utc(end_date)
1125 self.executor = executor
1126 self.executor_config = executor_config or {}
1127 self.run_as_user = run_as_user
1128 # TODO:
1129 # self.retries = parse_retries(retries)
1130 self.retries = retries
1131 self.queue = queue
1132 self.pool = DEFAULT_POOL_NAME if pool is None else pool
1133 self.pool_slots = pool_slots
1134 if self.pool_slots < 1:
1135 dag_str = f" in dag {dag.dag_id}" if dag else ""
1136 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
1137 if sla is not None:
1138 warnings.warn(
1139 "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1",
1140 stacklevel=2,
1141 )
1142
1143 try:
1144 TriggerRule(trigger_rule)
1145 except ValueError:
            raise ValueError(
                f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]} "
                f"for task '{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
            )
1150
1151 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
1152
1153 self.depends_on_past: bool = depends_on_past
1154 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
1155 self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
1156 self.wait_for_downstream: bool = wait_for_downstream
1157 if wait_for_downstream:
1158 self.depends_on_past = True
1159
1160 # Converted by setattr
1161 self.retry_delay = retry_delay # type: ignore[assignment]
1162 self.retry_exponential_backoff = retry_exponential_backoff
1163 if max_retry_delay is not None:
1164 self.max_retry_delay = max_retry_delay
1165
1166 self.resources = resources
1167
1168 self.params = ParamsDict(params)
1169
1170 self.priority_weight = priority_weight
1171 self.weight_rule = weight_rule
1172
1173 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
1174 self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
1175 self.do_xcom_push: bool = do_xcom_push
1176 self.map_index_template: str | None = map_index_template
1177 self.multiple_outputs: bool = multiple_outputs
1178
1179 self.doc_md = doc_md
1180 self.doc_json = doc_json
1181 self.doc_yaml = doc_yaml
1182 self.doc_rst = doc_rst
1183 self.doc = doc
1184
1185 self._task_display_name = task_display_name
1186
1187 self.allow_nested_operators = allow_nested_operators
1188
1189 self.render_template_as_native_obj = render_template_as_native_obj
1190
1191 self._logger_name = logger_name
1192
1193 # Lineage
1194 self.inlets = _collect_from_input(inlets)
1195 self.outlets = _collect_from_input(outlets)
1196
1197 if isinstance(self.template_fields, str):
1198 warnings.warn(
                f"The `template_fields` value for {self.task_type} is a string "
                "but should be a list or tuple of strings. Wrapping it in a list for execution. "
                f"Please update {self.task_type} accordingly.",
1202 UserWarning,
1203 stacklevel=2,
1204 )
1205 self.template_fields = [self.template_fields]
1206
1207 self.is_setup = False
1208 self.is_teardown = False
1209
1210 if SetupTeardownContext.active:
1211 SetupTeardownContext.update_context_map(self)
1212
1213 # We set self.dag right at the end as `_convert_dag` calls `dag.add_task` for us, and we need all the
1214 # other properties to be set at that point
1215 if dag is not None:
1216 self.dag = dag
1217
1218 validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
1219
1220 # Ensure priority_weight is within the valid range
1221 self.priority_weight = db_safe_priority(self.priority_weight)
1222
1223 def __eq__(self, other):
1224 if type(self) is type(other):
1225 # Use getattr() instead of __dict__ as __dict__ doesn't return
1226 # correct values for properties.
1227 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
1228 return False
1229
1230 def __ne__(self, other):
1231 return not self == other
1232
1233 def __hash__(self):
1234 hash_components = [type(self)]
1235 for component in self._comps:
1236 val = getattr(self, component, None)
1237 try:
1238 hash(val)
1239 hash_components.append(val)
1240 except TypeError:
1241 hash_components.append(repr(val))
1242 return hash(tuple(hash_components))
1243
1244 # /Composing Operators ---------------------------------------------
1245
1246 def __gt__(self, other):
1247 """
1248 Return [Operator] > [Outlet].
1249
1250 If other is an attr annotated object it is set as an outlet of this Operator.
1251 """
1252 if not isinstance(other, Iterable):
1253 other = [other]
1254
1255 for obj in other:
1256 if not attrs.has(obj):
1257 raise TypeError(f"Left hand side ({obj}) is not an outlet")
1258 self.add_outlets(other)
1259
1260 return self
1261
1262 def __lt__(self, other):
1263 """
1264 Return [Inlet] > [Operator] or [Operator] < [Inlet].
1265
1266 If other is an attr annotated object it is set as an inlet to this operator.
1267 """
1268 if not isinstance(other, Iterable):
1269 other = [other]
1270
1271 for obj in other:
1272 if not attrs.has(obj):
1273 raise TypeError(f"{obj} cannot be an inlet")
1274 self.add_inlets(other)
1275
1276 return self
1277
1278 def __deepcopy__(self, memo: dict[int, Any]):
1279 # Hack sorting double chained task lists by task_id to avoid hitting
1280 # max_depth on deepcopy operations.
1281 sys.setrecursionlimit(5000) # TODO fix this in a better way
1282
1283 cls = self.__class__
1284 result = cls.__new__(cls)
1285 memo[id(self)] = result
1286
1287 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs
1288
1289 for k, v_org in self.__dict__.items():
1290 if k not in shallow_copy:
1291 v = copy.deepcopy(v_org, memo)
1292 else:
1293 v = copy.copy(v_org)
1294
            # Bypass any setters, and set it on the object directly. This works since we are cloning
            # ourselves, so we know the type is already fine
1297 result.__dict__[k] = v
1298 return result
1299
1300 def __getstate__(self):
1301 state = dict(self.__dict__)
1302 if "_log" in state:
1303 del state["_log"]
1304
1305 return state
1306
1307 def __setstate__(self, state):
1308 self.__dict__ = state
1309
1310 def add_inlets(self, inlets: Iterable[Any]):
1311 """Set inlets to this operator."""
1312 self.inlets.extend(inlets)
1313
1314 def add_outlets(self, outlets: Iterable[Any]):
1315 """Define the outlets of this operator."""
1316 self.outlets.extend(outlets)
1317
1318 def get_dag(self) -> DAG | None:
1319 return self._dag
1320
1321 @property
1322 def dag(self) -> DAG:
1323 """Returns the Operator's Dag if set, otherwise raises an error."""
1324 if dag := self._dag:
1325 return dag
1326 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
1327
1328 @dag.setter
1329 def dag(self, dag: DAG | None) -> None:
1330 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok."""
1331 self._dag = dag
1332
1333 def _convert__dag(self, dag: DAG | None) -> DAG | None:
1334 # Called automatically by __setattr__ method
1335 from airflow.sdk.definitions.dag import DAG
1336
1337 if dag is None:
1338 return dag
1339
1340 if not isinstance(dag, DAG):
1341 raise TypeError(f"Expected dag; received {dag.__class__.__name__}")
1342 if self._dag is not None and self._dag is not dag:
1343 raise ValueError(f"The dag assigned to {self} can not be changed.")
1344
1345 if self.__from_mapped:
1346 pass # Don't add to dag -- the mapped task takes the place.
1347 elif dag.task_dict.get(self.task_id) is not self:
1348 dag.add_task(self)
1349 return dag
1350
1351 @staticmethod
1352 def _convert_retries(retries: Any) -> int | None:
1353 if retries is None:
1354 return 0
1355 if type(retries) == int: # noqa: E721
1356 return retries
1357 try:
1358 parsed_retries = int(retries)
1359 except (TypeError, ValueError):
1360 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}")
1361 return parsed_retries
1362
1363 @staticmethod
1364 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None:
1365 if value is None or isinstance(value, timedelta):
1366 return value
1367 return timedelta(seconds=value)
1368
1369 _convert_retry_delay = _convert_timedelta
1370 _convert_max_retry_delay = _convert_timedelta
1371
1372 @staticmethod
1373 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None:
1374 if resources is None:
1375 return None
1376
1377 from airflow.sdk.definitions.operator_resources import Resources
1378
1379 if isinstance(resources, Resources):
1380 return resources
1381
1382 return Resources(**resources)
1383
1384 def _convert_is_setup(self, value: bool) -> bool:
1385 """
1386 Setter for is_setup property.
1387
1388 :meta private:
1389 """
1390 if self.is_teardown and value:
1391 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.")
1392 return value
1393
1394 def _convert_is_teardown(self, value: bool) -> bool:
1395 if self.is_setup and value:
1396 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.")
1397 return value
1398
1399 @property
1400 def task_display_name(self) -> str:
1401 return self._task_display_name or self.task_id
1402
1403 def has_dag(self):
1404 """Return True if the Operator has been assigned to a Dag."""
1405 return self._dag is not None
1406
1407 def _set_xcomargs_dependencies(self) -> None:
1408 from airflow.sdk.definitions.xcom_arg import XComArg
1409
1410 for f in self.template_fields:
1411 arg = getattr(self, f, NOTSET)
1412 if arg is not NOTSET:
1413 XComArg.apply_upstream_relationship(self, arg)
1414
1415 def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None:
1416 """
1417 Resolve upstream dependencies of a task.
1418
1419 In this way, passing an ``XComArg`` as the value of a template field
1420 will create an upstream relation between the two tasks.
1421
1422 **Example**: ::
1423
1424 with DAG(...):
1425 generate_content = GenerateContentOperator(task_id="generate_content")
1426 send_email = EmailOperator(..., html_content=generate_content.output)
1427
1428 # This is equivalent to
1429 with DAG(...):
1430 generate_content = GenerateContentOperator(task_id="generate_content")
1431 send_email = EmailOperator(
1432 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
1433 )
1434 generate_content >> send_email
1435
1436 """
1437 from airflow.sdk.definitions.xcom_arg import XComArg
1438
1439 if field not in self.template_fields:
1440 return
1441 XComArg.apply_upstream_relationship(self, newvalue)
1442
1443 def on_kill(self) -> None:
1444 """
1445 Override this method to clean up subprocesses when a task instance gets killed.
1446
1447 Any use of the threading, subprocess or multiprocessing module within an
1448 operator needs to be cleaned up, or it will leave ghost processes behind.
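
A minimal sketch of such cleanup, assuming a hypothetical operator that keeps a
handle to its child process on ``self._process`` (neither the class nor the
attribute is part of this SDK):

.. code-block:: python

    import subprocess


    class LongRunningOperator(BaseOperator):
        def execute(self, context):
            # Hypothetical long-running child process tracked on the instance.
            self._process = subprocess.Popen(["sleep", "3600"])
            self._process.wait()

        def on_kill(self) -> None:
            # Terminate the child started in execute() so no ghost process survives.
            if getattr(self, "_process", None) is not None:
                self._process.terminate()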
1449 """
1450
1451 def __repr__(self):
1452 return f"<Task({self.task_type}): {self.task_id}>"
1453
1454 @property
1455 def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
1456 return self.__class__
1457
1458 @property
1459 def task_type(self) -> str:
1460 """@property: type of the task."""
1461 return self.__class__.__name__
1462
1463 @property
1464 def operator_name(self) -> str:
1465 """@property: use a more friendly display name for the operator, if set."""
1466 try:
1467 return self.custom_operator_name # type: ignore
1468 except AttributeError:
1469 return self.task_type
1470
1471 @property
1472 def roots(self) -> list[BaseOperator]:
1473 """Required by DAGNode."""
1474 return [self]
1475
1476 @property
1477 def leaves(self) -> list[BaseOperator]:
1478 """Required by DAGNode."""
1479 return [self]
1480
1481 @property
1482 def output(self) -> XComArg:
1483 """Return a reference to the XCom pushed by the current operator."""
1484 from airflow.sdk.definitions.xcom_arg import XComArg
1485
1486 return XComArg(operator=self)
1487
1488 @classmethod
1489 def get_serialized_fields(cls):
1490 """Stringified Dags and operators contain exactly these fields."""
1491 if not cls.__serialized_fields:
1492 from airflow.sdk.definitions._internal.contextmanager import DagContext
1493
1494 # make sure the following "fake" task is not added to the currently active
1495 # dag in context; otherwise it will result in a
1496 # `RuntimeError: dictionary changed size during iteration`
1497 # exception in the SerializedDAG.serialize_dag() call.
1498 DagContext.push(None)
1499 cls.__serialized_fields = frozenset(
1500 vars(BaseOperator(task_id="test")).keys()
1501 - {
1502 "upstream_task_ids",
1503 "default_args",
1504 "dag",
1505 "_dag",
1506 "label",
1507 "_BaseOperator__instantiated",
1508 "_BaseOperator__init_kwargs",
1509 "_BaseOperator__from_mapped",
1510 "on_failure_fail_dagrun",
1511 "task_group",
1512 "_task_type",
1513 "operator_extra_links",
1514 "on_execute_callback",
1515 "on_failure_callback",
1516 "on_success_callback",
1517 "on_retry_callback",
1518 "on_skipped_callback",
1519 }
1520 | { # Class level defaults, or `@property` need to be added to this list
1521 "start_date",
1522 "end_date",
1523 "task_type",
1524 "ui_color",
1525 "ui_fgcolor",
1526 "template_ext",
1527 "template_fields",
1528 "template_fields_renderers",
1529 "params",
1530 "is_setup",
1531 "is_teardown",
1532 "on_failure_fail_dagrun",
1533 "map_index_template",
1534 "start_trigger_args",
1535 "_needs_expansion",
1536 "start_from_trigger",
1537 "max_retry_delay",
1538 "has_on_execute_callback",
1539 "has_on_failure_callback",
1540 "has_on_success_callback",
1541 "has_on_retry_callback",
1542 "has_on_skipped_callback",
1543 }
1544 )
1545 DagContext.pop()
1546
1547 return cls.__serialized_fields
1548
1549 def prepare_for_execution(self) -> Self:
1550 """Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
1551 other = copy.copy(self)
1552 other._lock_for_execution = True
1553 return other
1554
1555 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
1556 """Serialize; required by DAGNode."""
1557 from airflow.sdk.api.datamodels._generated import DagAttributeTypes
1558
1559 return DagAttributeTypes.OP, self.task_id
1560
1561 def unmap(self, resolve: None | Mapping[str, Any]) -> Self:
1562 """
1563 Get the "normal" operator from the current operator.
1564
1565 Since a BaseOperator is not mapped to begin with, this simply returns
1566 the original operator.
1567
1568 :meta private:
1569 """
1570 return self
1571
1572 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
1573 """
1574 Get the start_trigger_args value of the current abstract operator.
1575
1576 Since a BaseOperator is not mapped to begin with, this simply returns
1577 the original value of start_trigger_args.
1578
1579 :meta private:
1580 """
1581 return self.start_trigger_args
1582
1583 def render_template_fields(
1584 self,
1585 context: Context,
1586 jinja_env: jinja2.Environment | None = None,
1587 ) -> None:
1588 """
1589 Template all attributes listed in *self.template_fields*.
1590
1591 This mutates the attributes in-place and is irreversible.
1592
1593 :param context: Context dict with values to apply on content.
1594 :param jinja_env: Jinja's environment to use for rendering.
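
For illustration, a hypothetical operator whose ``template_fields`` include
``"message"`` (not an SDK class); ``context`` is the runtime Context mapping
normally provided by the task runner:

.. code-block:: python

    class GreetOperator(BaseOperator):
        template_fields = ("message",)

        def __init__(self, message: str, **kwargs):
            super().__init__(**kwargs)
            self.message = message


    op = GreetOperator(task_id="greet", message="Run for {{ ds }}")
    # Normally invoked by the task runner before execute(); afterwards
    # op.message holds the string rendered against the runtime context.
    op.render_template_fields(context)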
1595 """
1596 if not jinja_env:
1597 jinja_env = self.get_template_env()
1598 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
1599
1600 def pre_execute(self, context: Any):
1601 """Execute right before self.execute() is called."""
1602
1603 def execute(self, context: Context) -> Any:
1604 """
1605 Override this method when deriving a custom operator.
1606
1607 This is the main method to execute the task. Context is the same dictionary
1608 used when rendering jinja templates.
1609
1610 Refer to get_template_context for more context.
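
A minimal sketch of a subclass implementing ``execute`` (``HelloOperator`` and its
``name`` argument are illustrative, not part of the SDK):

.. code-block:: python

    class HelloOperator(BaseOperator):
        def __init__(self, name: str, **kwargs):
            super().__init__(**kwargs)
            self.name = name

        def execute(self, context):
            # The return value is pushed to XCom under the key ``return_value``.
            message = f"Hello {self.name}"
            self.log.info(message)
            return message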
1611 """
1612 raise NotImplementedError()
1613
1614 def post_execute(self, context: Any, result: Any = None):
1615 """
1616 Execute right after self.execute() is called.
1617
1618 It is passed the execution context and any results returned by the operator.
1619 """
1620
1621 def defer(
1622 self,
1623 *,
1624 trigger: BaseTrigger,
1625 method_name: str,
1626 kwargs: dict[str, Any] | None = None,
1627 timeout: timedelta | int | float | None = None,
1628 ) -> NoReturn:
1629 """
1630 Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
1631
1632 This is achieved by raising a special exception (TaskDeferred)
1633 which is caught in the main _execute_task wrapper. Triggers can send execution back to the task or end
1634 the task instance directly. If the trigger ends the task instance itself, ``method_name`` should
1635 be None; otherwise, provide the name of the method that should be used when resuming execution in
1636 the task.
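
A sketch of how an operator might defer itself. ``TimeDeltaTrigger`` and its import
path from the standard provider are shown only for illustration (the path may differ
across versions); any ``BaseTrigger`` subclass can be used instead:

.. code-block:: python

    from datetime import timedelta

    from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger


    class WaitOperator(BaseOperator):
        def execute(self, context):
            self.defer(
                trigger=TimeDeltaTrigger(timedelta(minutes=5)),
                method_name="execute_complete",
            )

        def execute_complete(self, context, event=None):
            # Called on the worker once the trigger fires; ``event`` is the trigger payload.
            return event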
1637 """
1638 from airflow.sdk.exceptions import TaskDeferred
1639
1640 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
1641
1642 def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
1643 """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed."""
1644 from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout
1645
1646 if next_kwargs is None:
1647 next_kwargs = {}
1648 # __fail__ is a special signal value for next_method that indicates
1649 # this task was scheduled specifically to fail.
1650
1651 if next_method == TRIGGER_FAIL_REPR:
1652 next_kwargs = next_kwargs or {}
1653 traceback = next_kwargs.get("traceback")
1654 if traceback is not None:
1655 self.log.error("Trigger failed:\n%s", "\n".join(traceback))
1656 if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT:
1657 raise TaskDeferralTimeout(error)
1658 raise TaskDeferralError(error)
1659 # Grab the callable off the Operator/Task and add in any kwargs
1660 execute_callable = getattr(self, next_method)
1661 return execute_callable(context, **next_kwargs)
1662
1663 def dry_run(self) -> None:
1664 """Perform a dry run of the operator: just render the template fields."""
1665 self.log.info("Dry run")
1666 for f in self.template_fields:
1667 try:
1668 content = getattr(self, f)
1669 except AttributeError:
1670 raise AttributeError(
1671 f"{f!r} is configured as a template field "
1672 f"but {self.task_type} does not have this attribute."
1673 )
1674
1675 if content and isinstance(content, str):
1676 self.log.info("Rendering template for %s", f)
1677 self.log.info(content)
1678
1679 @property
1680 def has_on_execute_callback(self) -> bool:
1681 """Return True if the task has execute callbacks."""
1682 return bool(self.on_execute_callback)
1683
1684 @property
1685 def has_on_failure_callback(self) -> bool:
1686 """Return True if the task has failure callbacks."""
1687 return bool(self.on_failure_callback)
1688
1689 @property
1690 def has_on_success_callback(self) -> bool:
1691 """Return True if the task has success callbacks."""
1692 return bool(self.on_success_callback)
1693
1694 @property
1695 def has_on_retry_callback(self) -> bool:
1696 """Return True if the task has retry callbacks."""
1697 return bool(self.on_retry_callback)
1698
1699 @property
1700 def has_on_skipped_callback(self) -> bool:
1701 """Return True if the task has skipped callbacks."""
1702 return bool(self.on_skipped_callback)
1703
1704
1705class BaseAsyncOperator(BaseOperator):
1706 """
1707 Base class for async-capable operators.
1708
1709 As opposed to deferrable operators, whose deferred portion is executed on the triggerer, async
1710 operators are executed on the worker.
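
A minimal sketch of a subclass (names are illustrative):

.. code-block:: python

    import asyncio


    class SleepAsyncOperator(BaseAsyncOperator):
        async def aexecute(self, context):
            # Runs on the worker, inside the event loop created by execute().
            await asyncio.sleep(1)
            return "done"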
1711 """
1712
1713 @property
1714 def is_async(self) -> bool:
1715 return True
1716
1717 async def aexecute(self, context):
1718 """Async version of execute(). Subclasses should implement this."""
1719 raise NotImplementedError()
1720
1721 def execute(self, context):
1722 """Run `aexecute()` inside an event loop."""
1723 with event_loop() as loop:
1724 if self.execution_timeout:
1725 return loop.run_until_complete(
1726 asyncio.wait_for(
1727 self.aexecute(context),
1728 timeout=self.execution_timeout.total_seconds(),
1729 )
1730 )
1731 return loop.run_until_complete(self.aexecute(context))
1732
1733
1734def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
1735 r"""
1736 Given a number of tasks, builds a dependency chain.
1737
1738 This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
1739 or lists containing any mix of these types (or a mix in the same list). If you want to chain between two
1740 lists you must ensure they have the same length.
1741
1742 Using classic operators/sensors:
1743
1744 .. code-block:: python
1745
1746 chain(t1, [t2, t3], [t4, t5], t6)
1747
1748 is equivalent to::
1749
1750 / -> t2 -> t4 \
1751 t1 -> t6
1752 \ -> t3 -> t5 /
1753
1754 .. code-block:: python
1755
1756 t1.set_downstream(t2)
1757 t1.set_downstream(t3)
1758 t2.set_downstream(t4)
1759 t3.set_downstream(t5)
1760 t4.set_downstream(t6)
1761 t5.set_downstream(t6)
1762
1763 Using task-decorated functions aka XComArgs:
1764
1765 .. code-block:: python
1766
1767 chain(x1(), [x2(), x3()], [x4(), x5()], x6())
1768
1769 is equivalent to::
1770
1771 / -> x2 -> x4 \
1772 x1 -> x6
1773 \ -> x3 -> x5 /
1774
1775 .. code-block:: python
1776
1777 x1 = x1()
1778 x2 = x2()
1779 x3 = x3()
1780 x4 = x4()
1781 x5 = x5()
1782 x6 = x6()
1783 x1.set_downstream(x2)
1784 x1.set_downstream(x3)
1785 x2.set_downstream(x4)
1786 x3.set_downstream(x5)
1787 x4.set_downstream(x6)
1788 x5.set_downstream(x6)
1789
1790 Using TaskGroups:
1791
1792 .. code-block:: python
1793
1794 chain(t1, task_group1, task_group2, t2)
1795
1796 t1.set_downstream(task_group1)
1797 task_group1.set_downstream(task_group2)
1798 task_group2.set_downstream(t2)
1799
1800
1801 It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:
1802
1803 .. code-block:: python
1804
1805 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
1806
1807 is equivalent to::
1808
1809 / "branch one" -> x1 \
1810 t1 -> task_group1 -> x3
1811 \ "branch two" -> x2 /
1812
1813 .. code-block:: python
1814
1815 x1 = x1()
1816 x2 = x2()
1817 x3 = x3()
1818 label1 = Label("branch one")
1819 label2 = Label("branch two")
1820 t1.set_downstream(label1)
1821 label1.set_downstream(x1)
1822 t1.set_downstream(label2)
1823 label2.set_downstream(x2)
1824 x1.set_downstream(task_group1)
1825 x2.set_downstream(task_group1)
1826 task_group1.set_downstream(x3)
1827
1828 # or
1829
1830 x1 = x1()
1831 x2 = x2()
1832 x3 = x3()
1833 t1.set_downstream(x1, edge_modifier=Label("branch one"))
1834 t1.set_downstream(x2, edge_modifier=Label("branch two"))
1835 x1.set_downstream(task_group1)
1836 x2.set_downstream(task_group1)
1837 task_group1.set_downstream(x3)
1838
1839
1840 :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
1841 """
1842 for up_task, down_task in zip(tasks, tasks[1:]):
1843 if isinstance(up_task, DependencyMixin):
1844 up_task.set_downstream(down_task)
1845 continue
1846 if isinstance(down_task, DependencyMixin):
1847 down_task.set_upstream(up_task)
1848 continue
1849 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
1850 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
1851 up_task_list = up_task
1852 down_task_list = down_task
1853 if len(up_task_list) != len(down_task_list):
1854 raise ValueError(
1855 f"Chain not supported for Iterables of different lengths. "
1856 f"Got {len(up_task_list)} and {len(down_task_list)}."
1857 )
1858 for up_t, down_t in zip(up_task_list, down_task_list):
1859 up_t.set_downstream(down_t)
1860
1861
1862def cross_downstream(
1863 from_tasks: Sequence[DependencyMixin],
1864 to_tasks: DependencyMixin | Sequence[DependencyMixin],
1865):
1866 r"""
1867 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
1868
1869 Using classic operators/sensors:
1870
1871 .. code-block:: python
1872
1873 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
1874
1875 is equivalent to::
1876
1877 t1 ---> t4
1878 \ /
1879 t2 -X -> t5
1880 / \
1881 t3 ---> t6
1882
1883 .. code-block:: python
1884
1885 t1.set_downstream(t4)
1886 t1.set_downstream(t5)
1887 t1.set_downstream(t6)
1888 t2.set_downstream(t4)
1889 t2.set_downstream(t5)
1890 t2.set_downstream(t6)
1891 t3.set_downstream(t4)
1892 t3.set_downstream(t5)
1893 t3.set_downstream(t6)
1894
1895 Using task-decorated functions aka XComArgs:
1896
1897 .. code-block:: python
1898
1899 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
1900
1901 is equivalent to::
1902
1903 x1 ---> x4
1904 \ /
1905 x2 -X -> x5
1906 / \
1907 x3 ---> x6
1908
1909 .. code-block:: python
1910
1911 x1 = x1()
1912 x2 = x2()
1913 x3 = x3()
1914 x4 = x4()
1915 x5 = x5()
1916 x6 = x6()
1917 x1.set_downstream(x4)
1918 x1.set_downstream(x5)
1919 x1.set_downstream(x6)
1920 x2.set_downstream(x4)
1921 x2.set_downstream(x5)
1922 x2.set_downstream(x6)
1923 x3.set_downstream(x4)
1924 x3.set_downstream(x5)
1925 x3.set_downstream(x6)
1926
1927 It is also possible to mix between classic operator/sensor and XComArg tasks:
1928
1929 .. code-block:: python
1930
1931 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1932
1933 is equivalent to::
1934
1935 t1 ---> x1
1936 \ /
1937 x2 -X -> t2
1938 / \
1939 t3 ---> x3
1940
1941 .. code-block:: python
1942
1943 x1 = x1()
1944 x2 = x2()
1945 x3 = x3()
1946 t1.set_downstream(x1)
1947 t1.set_downstream(t2)
1948 t1.set_downstream(x3)
1949 x2.set_downstream(x1)
1950 x2.set_downstream(t2)
1951 x2.set_downstream(x3)
1952 t3.set_downstream(x1)
1953 t3.set_downstream(t2)
1954 t3.set_downstream(x3)
1955
1956 :param from_tasks: List of tasks or XComArgs to start from.
1957 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1958 """
1959 for task in from_tasks:
1960 task.set_downstream(to_tasks)
1961
1962
1963def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
1964 """
1965 Simplify task dependency definition.
1966
1967 E.g.: suppose you want precedence like so::
1968
1969 ╭─op2─╮ ╭─op4─╮
1970 op1─┤ ├─├─op5─┤─op7
1971 ╰─op3─╯ ╰─op6─╯
1972
1973 Then you can accomplish like so::
1974
1975 chain_linear(op1, [op2, op3], [op4, op5, op6], op7)
1976
1977 :param elements: a list of operators / lists of operators
1978 """
1979 if not elements:
1980 raise ValueError("No tasks provided; nothing to do.")
1981 prev_elem = None
1982 deps_set = False
1983 for curr_elem in elements:
1984 if isinstance(curr_elem, EdgeModifier):
1985 raise ValueError("Labels are not supported by chain_linear")
1986 if prev_elem is not None:
1987 for task in prev_elem:
1988 task >> curr_elem
1989 if not deps_set:
1990 deps_set = True
1991 prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
1992 if not deps_set:
1993 raise ValueError("No dependencies were set. Did you forget to expand with `*`?")