1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import abc
21import collections.abc
22import contextlib
23import copy
24import inspect
25import sys
26import warnings
27from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
28from contextvars import ContextVar
29from dataclasses import dataclass, field
30from datetime import datetime, timedelta
31from enum import Enum
32from functools import total_ordering, wraps
33from types import FunctionType
34from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast
35
36import attrs
37
38from airflow.sdk import TriggerRule, timezone
39from airflow.sdk._shared.secrets_masker import redact
40from airflow.sdk.definitions._internal.abstractoperator import (
41 DEFAULT_EMAIL_ON_FAILURE,
42 DEFAULT_EMAIL_ON_RETRY,
43 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
44 DEFAULT_OWNER,
45 DEFAULT_POOL_NAME,
46 DEFAULT_POOL_SLOTS,
47 DEFAULT_PRIORITY_WEIGHT,
48 DEFAULT_QUEUE,
49 DEFAULT_RETRIES,
50 DEFAULT_RETRY_DELAY,
51 DEFAULT_TASK_EXECUTION_TIMEOUT,
52 DEFAULT_TRIGGER_RULE,
53 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
54 DEFAULT_WEIGHT_RULE,
55 AbstractOperator,
56 DependencyMixin,
57 TaskStateChangeCallback,
58)
59from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack
60from airflow.sdk.definitions._internal.node import validate_key
61from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
62from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args
63from airflow.sdk.definitions.edges import EdgeModifier
64from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
65from airflow.sdk.definitions.param import ParamsDict
66from airflow.sdk.exceptions import RemovedInAirflow4Warning
67
68# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
69# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
70# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
71# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
72DB_SAFE_MINIMUM = -2147483648
73DB_SAFE_MAXIMUM = 2147483647
74
75
76def db_safe_priority(priority_weight: int) -> int:
77 """Convert priority weight to a safe value for the database."""
78 return max(DB_SAFE_MINIMUM, min(DB_SAFE_MAXIMUM, priority_weight))
79
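# A minimal sketch of the clamping behaviour implemented above: values outside the
# 32-bit signed integer range are pinned to the nearest bound, e.g.
#
#   db_safe_priority(5)          # -> 5 (already in range, unchanged)
#   db_safe_priority(10**12)     # -> 2147483647 (DB_SAFE_MAXIMUM)
#   db_safe_priority(-(10**12))  # -> -2147483648 (DB_SAFE_MINIMUM)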
80
81C = TypeVar("C", bound=Callable)
82T = TypeVar("T", bound=FunctionType)
83
84if TYPE_CHECKING:
85 from types import ClassMethodDescriptorType
86
87 import jinja2
88 from typing_extensions import Self
89
90 from airflow.sdk.bases.operatorlink import BaseOperatorLink
91 from airflow.sdk.definitions.context import Context
92 from airflow.sdk.definitions.dag import DAG
93 from airflow.sdk.definitions.operator_resources import Resources
94 from airflow.sdk.definitions.taskgroup import TaskGroup
95 from airflow.sdk.definitions.xcom_arg import XComArg
96 from airflow.serialization.enums import DagAttributeTypes
97 from airflow.task.priority_strategy import PriorityWeightStrategy
98 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
99
100 TaskPreExecuteHook = Callable[[Context], None]
101 TaskPostExecuteHook = Callable[[Context, Any], None]
102
103__all__ = [
104 "BaseOperator",
105 "chain",
106 "chain_linear",
107 "cross_downstream",
108]
109
110
111class TriggerFailureReason(str, Enum):
112 """
113 Reasons for trigger failures.
114
115 Internal use only.
116
117 :meta private:
118 """
119
120 TRIGGER_TIMEOUT = "Trigger timeout"
121 TRIGGER_FAILURE = "Trigger failure"
122
123
124TRIGGER_FAIL_REPR = "__fail__"
125"""String value to represent trigger failure.
126
127Internal use only.
128
129:meta private:
130"""
131
132
133def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
134 if not dag:
135 return {}, ParamsDict()
136 dag_args = copy.copy(dag.default_args)
137 dag_params = copy.deepcopy(dag.params)
138 dag_params._fill_missing_param_source("dag")
139 if task_group:
140 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
141 raise TypeError("default_args must be a mapping")
142 dag_args.update(task_group.default_args)
143 return dag_args, dag_params
144
145
146def get_merged_defaults(
147 dag: DAG | None,
148 task_group: TaskGroup | None,
149 task_params: collections.abc.MutableMapping | None,
150 task_default_args: dict | None,
151) -> tuple[dict, ParamsDict]:
152 args, params = _get_parent_defaults(dag, task_group)
153 if task_params:
154 if not isinstance(task_params, collections.abc.Mapping):
155 raise TypeError(f"params must be a mapping, got {type(task_params)}")
156
157 task_params = ParamsDict(task_params)
158 task_params._fill_missing_param_source("task")
159 params.update(task_params)
160
161 if task_default_args:
162 if not isinstance(task_default_args, collections.abc.Mapping):
            raise TypeError(f"default_args must be a mapping, got {type(task_default_args)}")
164 args.update(task_default_args)
165 with contextlib.suppress(KeyError):
166 if params_from_default_args := ParamsDict(task_default_args["params"] or {}):
167 params_from_default_args._fill_missing_param_source("task")
168 params.update(params_from_default_args)
169
170 return args, params
171
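# Illustrative sketch (not executed here) of the precedence implemented above: task-level
# values win over task-group defaults, which in turn win over Dag-level defaults, and the
# same ordering applies to params.
#
#   dag = DAG("example", default_args={"retries": 1}, params={"env": "dev"})
#   args, params = get_merged_defaults(
#       dag=dag,
#       task_group=None,
#       task_params={"env": "prod"},
#       task_default_args={"retries": 3},
#   )
#   # args["retries"] == 3, params["env"] == "prod"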
172
173def parse_retries(retries: Any) -> int | None:
174 if retries is None:
175 return 0
176 if type(retries) == int: # noqa: E721
177 return retries
178 try:
179 parsed_retries = int(retries)
180 except (TypeError, ValueError):
181 raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}")
182 return parsed_retries
183
184
185def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta:
186 if isinstance(value, timedelta):
187 return value
188 return timedelta(seconds=value)
189
190
191def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
192 if resources is None:
193 return None
194 from airflow.sdk.definitions.operator_resources import Resources
195
196 return Resources(**resources)
197
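# Illustrative sketch of the coercion above (``cpus``/``ram`` are assumed to be valid
# ``Resources`` constructor arguments):
#
#   coerce_resources({"cpus": 2, "ram": 1024})  # -> Resources(cpus=2, ram=1024)
#   coerce_resources(None)                      # -> None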
198
199class _PartialDescriptor:
200 """A descriptor that guards against ``.partial`` being called on Task objects."""
201
202 class_method: ClassMethodDescriptorType | None = None
203
204 def __get__(
205 self, obj: BaseOperator, cls: type[BaseOperator] | None = None
206 ) -> Callable[..., OperatorPartial]:
207 # Call this "partial" so it looks nicer in stack traces.
208 def partial(**kwargs):
209 raise TypeError("partial can only be called on Operator classes, not Tasks themselves")
210
211 if obj is not None:
212 return partial
213 return self.class_method.__get__(cls, cls)
214
215
216OPERATOR_DEFAULTS: dict[str, Any] = {
217 "allow_nested_operators": True,
218 "depends_on_past": False,
219 "email_on_failure": DEFAULT_EMAIL_ON_FAILURE,
220 "email_on_retry": DEFAULT_EMAIL_ON_RETRY,
221 "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
222 # "executor": DEFAULT_EXECUTOR,
223 "executor_config": {},
224 "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
225 "inlets": [],
226 "map_index_template": None,
227 "on_execute_callback": [],
228 "on_failure_callback": [],
229 "on_retry_callback": [],
230 "on_skipped_callback": [],
231 "on_success_callback": [],
232 "outlets": [],
233 "owner": DEFAULT_OWNER,
234 "pool_slots": DEFAULT_POOL_SLOTS,
235 "priority_weight": DEFAULT_PRIORITY_WEIGHT,
236 "queue": DEFAULT_QUEUE,
237 "retries": DEFAULT_RETRIES,
238 "retry_delay": DEFAULT_RETRY_DELAY,
239 "retry_exponential_backoff": 0,
240 "trigger_rule": DEFAULT_TRIGGER_RULE,
241 "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
242 "wait_for_downstream": False,
243 "weight_rule": DEFAULT_WEIGHT_RULE,
244}
245
246
247# This is what handles the actual mapping.
248
249if TYPE_CHECKING:
250
251 def partial(
252 operator_class: type[BaseOperator],
253 *,
254 task_id: str,
255 dag: DAG | None = None,
256 task_group: TaskGroup | None = None,
257 start_date: datetime = ...,
258 end_date: datetime = ...,
259 owner: str = ...,
260 email: None | str | Iterable[str] = ...,
261 params: collections.abc.MutableMapping | None = None,
262 resources: dict[str, Any] | None = ...,
263 trigger_rule: str = ...,
264 depends_on_past: bool = ...,
265 ignore_first_depends_on_past: bool = ...,
266 wait_for_past_depends_before_skipping: bool = ...,
267 wait_for_downstream: bool = ...,
268 retries: int | None = ...,
269 queue: str = ...,
270 pool: str = ...,
271 pool_slots: int = ...,
272 execution_timeout: timedelta | None = ...,
273 max_retry_delay: None | timedelta | float = ...,
274 retry_delay: timedelta | float = ...,
275 retry_exponential_backoff: float = ...,
276 priority_weight: int = ...,
277 weight_rule: str | PriorityWeightStrategy = ...,
278 sla: timedelta | None = ...,
279 map_index_template: str | None = ...,
280 max_active_tis_per_dag: int | None = ...,
281 max_active_tis_per_dagrun: int | None = ...,
282 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
283 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
284 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
285 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
286 on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
287 run_as_user: str | None = ...,
288 executor: str | None = ...,
289 executor_config: dict | None = ...,
290 inlets: Any | None = ...,
291 outlets: Any | None = ...,
292 doc: str | None = ...,
293 doc_md: str | None = ...,
294 doc_json: str | None = ...,
295 doc_yaml: str | None = ...,
296 doc_rst: str | None = ...,
297 task_display_name: str | None = ...,
298 logger_name: str | None = ...,
299 allow_nested_operators: bool = True,
300 **kwargs,
301 ) -> OperatorPartial: ...
302else:
303
304 def partial(
305 operator_class: type[BaseOperator],
306 *,
307 task_id: str,
308 dag: DAG | None = None,
309 task_group: TaskGroup | None = None,
310 params: collections.abc.MutableMapping | None = None,
311 **kwargs,
312 ):
313 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
314
315 validate_mapping_kwargs(operator_class, "partial", kwargs)
316
317 dag = dag or DagContext.get_current()
318 if dag:
319 task_group = task_group or TaskGroupContext.get_current(dag)
320 if task_group:
321 task_id = task_group.child_id(task_id)
322
323 # Merge Dag and task group level defaults into user-supplied values.
324 dag_default_args, partial_params = get_merged_defaults(
325 dag=dag,
326 task_group=task_group,
327 task_params=params,
328 task_default_args=kwargs.pop("default_args", None),
329 )
330
331 # Create partial_kwargs from args and kwargs
332 partial_kwargs: dict[str, Any] = {
333 "task_id": task_id,
334 "dag": dag,
335 "task_group": task_group,
336 **kwargs,
337 }
338
339 # Inject Dag-level default args into args provided to this function.
340 # Most of the default args will be retrieved during unmapping; here we
341 # only ensure base properties are correctly set for the scheduler.
342 partial_kwargs.update(
343 (k, v)
344 for k, v in dag_default_args.items()
345 if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names
346 )
347
348 # Fill fields not provided by the user with default values.
349 partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs)
350
351 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
352 if "task_concurrency" in kwargs: # Reject deprecated option.
353 raise TypeError("unexpected argument: task_concurrency")
354 if start_date := partial_kwargs.get("start_date", None):
355 partial_kwargs["start_date"] = timezone.convert_to_utc(start_date)
356 if end_date := partial_kwargs.get("end_date", None):
357 partial_kwargs["end_date"] = timezone.convert_to_utc(end_date)
358 if partial_kwargs["pool_slots"] < 1:
359 dag_str = ""
360 if dag:
361 dag_str = f" in dag {dag.dag_id}"
362 raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
363 if retries := partial_kwargs.get("retries"):
364 partial_kwargs["retries"] = BaseOperator._convert_retries(retries)
365 partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"])
366 partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay(
367 partial_kwargs.get("max_retry_delay", None)
368 )
369
        for k in ("execute", "failure", "success", "retry", "skipped"):
            attr = f"on_{k}_callback"
            partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr))
372
373 return OperatorPartial(
374 operator_class=operator_class,
375 kwargs=partial_kwargs,
376 params=partial_params,
377 )
378
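# The function above is typically reached through dynamic task mapping rather than being
# called directly; an illustrative sketch, assuming BashOperator is imported in a Dag file:
#
#   BashOperator.partial(task_id="say", retries=2).expand(
#       bash_command=["echo a", "echo b", "echo c"]
#   )
#
# ``partial()`` captures the constant keyword arguments (merged with Dag and task-group
# defaults as above) into an ``OperatorPartial``; ``expand()`` then turns it into a mapped
# operator with one task instance per input.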
379
380class ExecutorSafeguard:
381 """
382 The ExecutorSafeguard decorator.
383
    Checks that the execute method of an operator is not called manually outside
    the TaskInstance (i.e. outside the Task Runner), as we want to avoid bad
    mixing between decorated and classic operators.
387 """
388
389 test_mode: ClassVar[bool] = False
390 tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel")
391 sentinel_value: ClassVar[object] = object()
392
393 @classmethod
394 def decorator(cls, func):
395 @wraps(func)
396 def wrapper(self, *args, **kwargs):
397 sentinel_key = f"{self.__class__.__name__}__sentinel"
398 sentinel = kwargs.pop(sentinel_key, None)
399
400 with contextlib.ExitStack() as stack:
401 if sentinel is cls.sentinel_value:
402 token = cls.tracker.set(self)
403 sentinel = self
404 stack.callback(cls.tracker.reset, token)
405 else:
                    # No sentinel passed in explicitly; the subclass execute wrapper
                    # may already have received it and set the tracker, so check there.
407 sentinel = cls.tracker.get(None)
408
409 if not cls.test_mode and sentinel is not self:
410 message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!"
411 if not self.allow_nested_operators:
412 raise RuntimeError(message)
413 self.log.warning(message)
414
415 # Now that we've logged, set sentinel so that `super()` calls don't log again
416 token = cls.tracker.set(self)
417 stack.callback(cls.tracker.reset, token)
418
419 return func(self, *args, **kwargs)
420
421 return wrapper
422
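# Illustrative sketch of what the safeguard above catches, assuming the ``@task`` decorator
# and BashOperator are imported: calling an operator's ``execute()`` by hand inside another
# task does not go through the Task Runner, so no sentinel is set and the wrapper either
# logs a warning (``allow_nested_operators=True``, the default) or raises a RuntimeError
# (``allow_nested_operators=False``).
#
#   @task
#   def fan_out(**context):
#       BashOperator(task_id="nested", bash_command="echo hi").execute(context)
#       # -> logs "BashOperator.execute cannot be called outside of the Task Runner!"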
423
424if "airflow.configuration" in sys.modules:
425 # Don't try and import it if it's not already loaded
426 from airflow.sdk.configuration import conf
427
428 ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode")
429
430
431def _collect_from_input(value_or_values: None | C | Collection[C]) -> list[C]:
432 if not value_or_values:
433 return []
434 if isinstance(value_or_values, Collection):
435 return list(value_or_values)
436 return [value_or_values]
437
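# Normalisation examples for the helper above (``my_callback``, ``cb_a`` and ``cb_b`` are
# placeholder callables):
#
#   _collect_from_input(None)          # -> []
#   _collect_from_input(my_callback)   # -> [my_callback]
#   _collect_from_input([cb_a, cb_b])  # -> [cb_a, cb_b]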
438
439class BaseOperatorMeta(abc.ABCMeta):
440 """Metaclass of BaseOperator."""
441
442 @classmethod
443 def _apply_defaults(cls, func: T) -> T:
444 """
445 Look for an argument named "default_args", and fill the unspecified arguments from it.
446
447 Since python2.* isn't clear about which arguments are missing when
        calling a function, and this can be quite confusing with multi-level
449 inheritance and argument defaults, this decorator also alerts with
450 specific information about the missing arguments.
451 """
452 # Cache inspect.signature for the wrapper closure to avoid calling it
        # at every decorated invocation. A separate sig_cache is created per
        # decoration, i.e. each function decorated using apply_defaults will
455 # have a different sig_cache.
456 sig_cache = inspect.signature(func)
457 non_variadic_params = {
458 name: param
459 for (name, param) in sig_cache.parameters.items()
460 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
461 }
462 non_optional_args = {
463 name
464 for name, param in non_variadic_params.items()
465 if param.default == param.empty and name != "task_id"
466 }
467
468 fixup_decorator_warning_stack(func)
469
470 @wraps(func)
471 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
472 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
473
474 if args:
475 raise TypeError("Use keyword arguments when initializing operators")
476
477 instantiated_from_mapped = kwargs.pop(
478 "_airflow_from_mapped",
479 getattr(self, "_BaseOperator__from_mapped", False),
480 )
481
482 dag: DAG | None = kwargs.get("dag")
483 if dag is None:
484 dag = DagContext.get_current()
485 if dag is not None:
486 kwargs["dag"] = dag
487
488 task_group: TaskGroup | None = kwargs.get("task_group")
489 if dag and not task_group:
490 task_group = TaskGroupContext.get_current(dag)
491 if task_group is not None:
492 kwargs["task_group"] = task_group
493
494 default_args, merged_params = get_merged_defaults(
495 dag=dag,
496 task_group=task_group,
497 task_params=kwargs.pop("params", None),
498 task_default_args=kwargs.pop("default_args", None),
499 )
500
501 for arg in sig_cache.parameters:
502 if arg not in kwargs and arg in default_args:
503 kwargs[arg] = default_args[arg]
504
505 missing_args = non_optional_args.difference(kwargs)
506 if len(missing_args) == 1:
507 raise TypeError(f"missing keyword argument {missing_args.pop()!r}")
508 if missing_args:
509 display = ", ".join(repr(a) for a in sorted(missing_args))
510 raise TypeError(f"missing keyword arguments {display}")
511
512 if merged_params:
513 kwargs["params"] = merged_params
514
515 hook = getattr(self, "_hook_apply_defaults", None)
516 if hook:
517 args, kwargs = hook(**kwargs, default_args=default_args)
518 default_args = kwargs.pop("default_args", {})
519
520 if not hasattr(self, "_BaseOperator__init_kwargs"):
521 object.__setattr__(self, "_BaseOperator__init_kwargs", {})
522 object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped)
523
524 result = func(self, **kwargs, default_args=default_args)
525
526 # Store the args passed to init -- we need them to support task.map serialization!
527 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
528
529 # Set upstream task defined by XComArgs passed to template fields of the operator.
530 # BUT: only do this _ONCE_, not once for each class in the hierarchy
531 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
532 self._set_xcomargs_dependencies()
533 # Mark instance as instantiated so that future attr setting updates xcomarg-based deps.
534 object.__setattr__(self, "_BaseOperator__instantiated", True)
535
536 return result
537
538 apply_defaults.__non_optional_args = non_optional_args # type: ignore
539 apply_defaults.__param_names = set(non_variadic_params) # type: ignore
540
541 return cast("T", apply_defaults)
542
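    # Illustrative sketch of what ``_apply_defaults`` buys us: arguments a task omits are
    # filled from ``default_args`` defined on the Dag (or task group), so task "a" below
    # picks up ``retries=2`` while task "b" keeps its explicit value:
    #
    #   with DAG("example", default_args={"retries": 2, "owner": "data-eng"}):
    #       BashOperator(task_id="a", bash_command="echo a")
    #       BashOperator(task_id="b", bash_command="echo b", retries=0)
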
543 def __new__(cls, name, bases, namespace, **kwargs):
544 execute_method = namespace.get("execute")
545 if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
546 namespace["execute"] = ExecutorSafeguard.decorator(execute_method)
547 new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
548 with contextlib.suppress(KeyError):
549 # Update the partial descriptor with the class method, so it calls the actual function
550 # (but let subclasses override it if they need to)
551 partial_desc = vars(new_cls)["partial"]
552 if isinstance(partial_desc, _PartialDescriptor):
553 partial_desc.class_method = classmethod(partial)
554
555 # We patch `__init__` only if the class defines it.
556 first_superclass = new_cls.mro()[1]
557 if new_cls.__init__ is not first_superclass.__init__:
558 new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
559
560 return new_cls
561
562
563# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the
564# correct type. This is a temporary solution until we find a more sophisticated method for argument
565# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not
566# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
567# version that supports `get_type_hints` effectively or find a better approach, we can replace this
568# manual type-checking method.
569BASEOPERATOR_ARGS_EXPECTED_TYPES = {
570 "task_id": str,
571 "email": (str, Sequence),
572 "email_on_retry": bool,
573 "email_on_failure": bool,
574 "retries": int,
575 "retry_exponential_backoff": (int, float),
576 "depends_on_past": bool,
577 "ignore_first_depends_on_past": bool,
578 "wait_for_past_depends_before_skipping": bool,
579 "wait_for_downstream": bool,
580 "priority_weight": int,
581 "queue": str,
582 "pool": str,
583 "pool_slots": int,
584 "trigger_rule": str,
585 "run_as_user": str,
586 "task_concurrency": int,
587 "map_index_template": str,
588 "max_active_tis_per_dag": int,
589 "max_active_tis_per_dagrun": int,
590 "executor": str,
591 "do_xcom_push": bool,
592 "multiple_outputs": bool,
593 "doc": str,
594 "doc_md": str,
595 "doc_json": str,
596 "doc_yaml": str,
597 "doc_rst": str,
598 "task_display_name": str,
599 "logger_name": str,
600 "allow_nested_operators": bool,
601 "start_date": datetime,
602 "end_date": datetime,
603}
604
605
606# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in
607# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it.
608#
609# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you
610# get nowhere leave your record here for the next poor soul and what problems you ran in to.
611#
612# @ashb, 2024/10/14
613# - "Can't combine custom __setattr__ with on_setattr hooks"
# - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses
615@total_ordering
616@dataclass(repr=False)
617class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
618 r"""
619 Abstract base class for all operators.
620
621 Since operators create objects that become nodes in the Dag, BaseOperator
622 contains many recursive methods for Dag crawling behavior. To derive from
623 this class, you are expected to override the constructor and the 'execute'
624 method.
625
626 Operators derived from this class should perform or trigger certain tasks
627 synchronously (wait for completion). Example of operators could be an
628 operator that runs a Pig job (PigOperator), a sensor operator that
629 waits for a partition to land in Hive (HiveSensorOperator), or one that
630 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these
631 operators (tasks) target specific operations, running specific scripts,
632 functions or data transfers.
633
634 This class is abstract and shouldn't be instantiated. Instantiating a
635 class derived from this one results in the creation of a task object,
636 which ultimately becomes a node in Dag objects. Task dependencies should
637 be set by using the set_upstream and/or set_downstream methods.
638
639 :param task_id: a unique, meaningful id for the task
640 :param owner: the owner of the task. Using a meaningful description
641 (e.g. user/person/team/role name) to clarify ownership is recommended.
642 :param email: the 'to' email address(es) used in email alerts. This can be a
643 single email or multiple ones. Multiple addresses can be specified as a
644 comma or semicolon separated string or by passing a list of strings. (deprecated)
645 :param email_on_retry: Indicates whether email alerts should be sent when a
646 task is retried (deprecated)
647 :param email_on_failure: Indicates whether email alerts should be sent when
648 a task failed (deprecated)
649 :param retries: the number of retries that should be performed before
650 failing the task
651 :param retry_delay: delay between retries, can be set as ``timedelta`` or
652 ``float`` seconds, which will be converted into ``timedelta``,
653 the default is ``timedelta(seconds=300)``.
654 :param retry_exponential_backoff: multiplier for exponential backoff between retries.
655 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff
656 (delay doubles with each retry). For example, with retry_delay=4min and
657 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc.
658 :param max_retry_delay: maximum delay interval between retries, can be set as
659 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``.
660 :param start_date: The ``start_date`` for the task, determines
661 the ``logical_date`` for the first task instance. The best practice
662 is to have the start_date rounded
663 to your Dag's ``schedule_interval``. Daily jobs have their start_date
664 some day at 00:00:00, hourly jobs have their start_date at 00:00
665 of a specific hour. Note that Airflow simply looks at the latest
666 ``logical_date`` and adds the ``schedule_interval`` to determine
667 the next ``logical_date``. It is also very important
668 to note that different tasks' dependencies
669 need to line up in time. If task A depends on task B and their
        start_dates are offset in a way that their logical_dates don't line
671 up, A's dependencies will never be met. If you are looking to delay
672 a task, for example running a daily task at 2AM, look into the
673 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
674 dynamic ``start_date`` and recommend using fixed ones. Read the
675 FAQ entry about start_date for more information.
676 :param end_date: if specified, the scheduler won't go beyond this date
677 :param depends_on_past: when set to true, task instances will run
678 sequentially and only if the previous instance has succeeded or has been skipped.
679 The task instance for the start_date is allowed to run.
680 :param wait_for_past_depends_before_skipping: when set to true, if the task instance
        should be marked as skipped, and depends_on_past is true, the ti will stay in None state
        waiting for the task of the previous run
683 :param wait_for_downstream: when set to true, an instance of task
684 X will wait for tasks immediately downstream of the previous instance
685 of task X to finish successfully or be skipped before it runs. This is useful if the
686 different instances of a task X alter the same asset, and this asset
687 is used by tasks downstream of task X. Note that depends_on_past
688 is forced to True wherever wait_for_downstream is used. Also note that
689 only tasks *immediately* downstream of the previous task instance are waited
690 for; the statuses of any tasks further downstream are ignored.
691 :param dag: a reference to the dag the task is attached to (if any)
    :param priority_weight: priority weight of this task against other tasks.
693 This allows the executor to trigger higher priority tasks before
694 others when things get backed up. Set priority_weight as a higher
695 number for more important tasks.
        As not all database engines support 64-bit integers, values are capped to the 32-bit signed integer range.
697 Valid range is from -2,147,483,648 to 2,147,483,647.
698 :param weight_rule: weighting method used for the effective total
699 priority weight of the task. Options are:
700 ``{ downstream | upstream | absolute }`` default is ``downstream``
701 When set to ``downstream`` the effective weight of the task is the
702 aggregate sum of all downstream descendants. As a result, upstream
703 tasks will have higher weight and will be scheduled more aggressively
704 when using positive weight values. This is useful when you have
705 multiple dag run instances and desire to have all upstream tasks to
706 complete for all runs before each dag can continue processing
707 downstream tasks. When set to ``upstream`` the effective weight is the
708 aggregate sum of all upstream ancestors. This is the opposite where
709 downstream tasks have higher weight and will be scheduled more
710 aggressively when using positive weight values. This is useful when you
711 have multiple dag run instances and prefer to have each dag complete
712 before starting upstream tasks of other dags. When set to
713 ``absolute``, the effective weight is the exact ``priority_weight``
714 specified without additional weighting. You may want to do this when
715 you know exactly what priority weight each task should have.
        Additionally, when set to ``absolute``, there is a bonus effect of
        significantly speeding up the task creation process for very large
        Dags. Options can be set as a string or using the constants defined in
        the static class ``airflow.utils.WeightRule``.
        Irrespective of the weight rule, resulting priority values are capped to the 32-bit signed integer range.
721 |experimental|
        Since 2.9.0, Airflow allows defining a custom priority weight strategy
        by creating a subclass of
        ``airflow.task.priority_strategy.PriorityWeightStrategy``, registering
        it in a plugin, and then providing the class path or the class instance
        via the ``weight_rule`` parameter. The custom priority weight strategy
        will be used to calculate the effective total priority weight of the task instance.
728 :param queue: which queue to target when running this job. Not
        all executors implement queue management; the CeleryExecutor
730 does support targeting specific queues.
731 :param pool: the slot pool this task should run in, slot pools are a
732 way to limit concurrency for certain tasks
    :param pool_slots: the number of pool slots this task should use (>= 1).
734 Values less than 1 are not allowed.
735 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a
736 new implementation in Airflow >=3.1.
737 :param execution_timeout: max time allowed for the execution of
738 this task instance, if it goes beyond it will raise and fail.
739 :param on_failure_callback: a function or list of functions to be called when a task instance
        of this task fails. A context dictionary is passed as a single
        parameter to this function. The context contains references to related
        objects of the task instance and is documented under the macros
        section of the API.
744 :param on_execute_callback: much like the ``on_failure_callback`` except
745 that it is executed right before the task is executed.
746 :param on_retry_callback: much like the ``on_failure_callback`` except
747 that it is executed when retries occur.
748 :param on_success_callback: much like the ``on_failure_callback`` except
749 that it is executed when the task succeeds.
750 :param on_skipped_callback: much like the ``on_failure_callback`` except
        that it is executed when a skip occurs; this callback will be called only if AirflowSkipException gets raised.
        Explicitly, it is NOT called if the task never starts executing because of a preceding branching
        decision in the Dag or a trigger rule that causes execution to be skipped, so that the task execution
        is never scheduled.
755 :param pre_execute: a function to be called immediately before task
756 execution, receiving a context dictionary; raising an exception will
757 prevent the task from being executed.
758 :param post_execute: a function to be called immediately after task
759 execution, receiving a context dictionary and task result; raising an
760 exception will prevent the task from succeeding.
761 :param trigger_rule: defines the rule by which dependencies are applied
762 for the task to get triggered. Options are:
763 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
764 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
765 default is ``all_success``. Options can be set as string or
766 using the constants defined in the static class
767 ``airflow.utils.TriggerRule``
768 :param resources: A map of resource parameter names (the argument names of the
769 Resources constructor) to their values.
770 :param run_as_user: unix username to impersonate while running the task
771 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent
772 runs across logical_dates.
773 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent
774 task instances per Dag run.
775 :param executor: Which executor to target when running this task. NOT YET SUPPORTED
776 :param executor_config: Additional task-level configuration parameters that are
777 interpreted by a specific executor. Parameters are namespaced by the name of
        the executor.
779
780 **Example**: to run this task in a specific docker container through
781 the KubernetesExecutor ::
782
783 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}})
784
785 :param do_xcom_push: if True, an XCom is pushed containing the Operator's
786 result
787 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each
788 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom.
789 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not
790 using a TaskGroup as a context manager.
791 :param doc: Add documentation or notes to your Task objects that is visible in
792 Task Instance details View in the Webserver
793 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects
794 that is visible in Task Instance details View in the Webserver
795 :param doc_rst: Add documentation (in RST format) or notes to your Task objects
796 that is visible in Task Instance details View in the Webserver
797 :param doc_json: Add documentation (in JSON format) or notes to your Task objects
798 that is visible in Task Instance details View in the Webserver
799 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects
800 that is visible in Task Instance details View in the Webserver
801 :param task_display_name: The display name of the task which appears on the UI.
802 :param logger_name: Name of the logger used by the Operator to emit logs.
803 If set to `None` (default), the logger name will fall back to
804 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have
805 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger).
    :param allow_nested_operators: if True, when an operator is executed within another one, a warning message
807 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested
808 within another one). In future releases of Airflow this parameter will be removed and an exception
809 will always be thrown when operators are nested within each other (default is True).
810
811 **Example**: example of a bad operator mixin usage::
812
813 @task(provide_context=True)
814 def say_hello_world(**context):
815 hello_world_task = BashOperator(
816 task_id="hello_world_task",
817 bash_command="python -c \"print('Hello, world!')\"",
818 dag=dag,
819 )
820 hello_world_task.execute(context)
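
    **Example**: a minimal custom operator (an illustrative sketch; the operator, Dag and
    argument names are placeholders)::

        class HelloOperator(BaseOperator):
            def __init__(self, name: str, **kwargs):
                super().__init__(**kwargs)
                self.name = name

            def execute(self, context):
                message = f"Hello {self.name}"
                self.log.info(message)
                return message  # pushed to XCom, since do_xcom_push defaults to True

        with DAG(dag_id="hello_dag"):
            HelloOperator(task_id="hello", name="world")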
821 """
822
823 task_id: str
824 owner: str = DEFAULT_OWNER
825 email: str | Sequence[str] | None = None
826 email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY
827 email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE
828 retries: int | None = DEFAULT_RETRIES
829 retry_delay: timedelta = DEFAULT_RETRY_DELAY
830 retry_exponential_backoff: float = 0
831 max_retry_delay: timedelta | float | None = None
832 start_date: datetime | None = None
833 end_date: datetime | None = None
834 depends_on_past: bool = False
835 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
836 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
837 wait_for_downstream: bool = False
838
    # At execution time this becomes a normal dict
840 params: ParamsDict | dict = field(default_factory=ParamsDict)
841 default_args: dict | None = None
842 priority_weight: int = DEFAULT_PRIORITY_WEIGHT
843 weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
844 queue: str = DEFAULT_QUEUE
845 pool: str = DEFAULT_POOL_NAME
846 pool_slots: int = DEFAULT_POOL_SLOTS
847 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
848 on_execute_callback: Sequence[TaskStateChangeCallback] = ()
849 on_failure_callback: Sequence[TaskStateChangeCallback] = ()
850 on_success_callback: Sequence[TaskStateChangeCallback] = ()
851 on_retry_callback: Sequence[TaskStateChangeCallback] = ()
852 on_skipped_callback: Sequence[TaskStateChangeCallback] = ()
853 _pre_execute_hook: TaskPreExecuteHook | None = None
854 _post_execute_hook: TaskPostExecuteHook | None = None
855 trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE
856 resources: dict[str, Any] | None = None
857 run_as_user: str | None = None
858 task_concurrency: int | None = None
859 map_index_template: str | None = None
860 max_active_tis_per_dag: int | None = None
861 max_active_tis_per_dagrun: int | None = None
862 executor: str | None = None
863 executor_config: dict | None = None
864 do_xcom_push: bool = True
865 multiple_outputs: bool = False
866 inlets: list[Any] = field(default_factory=list)
867 outlets: list[Any] = field(default_factory=list)
868 task_group: TaskGroup | None = None
869 doc: str | None = None
870 doc_md: str | None = None
871 doc_json: str | None = None
872 doc_yaml: str | None = None
873 doc_rst: str | None = None
874 _task_display_name: str | None = None
875 logger_name: str | None = None
876 allow_nested_operators: bool = True
877
878 is_setup: bool = False
879 is_teardown: bool = False
880
881 # TODO: Task-SDK: Make these ClassVar[]?
882 template_fields: Collection[str] = ()
883 template_ext: Sequence[str] = ()
884
885 template_fields_renderers: ClassVar[dict[str, str]] = {}
886
887 operator_extra_links: Collection[BaseOperatorLink] = ()
888
889 # Defines the color in the UI
890 ui_color: str = "#fff"
891 ui_fgcolor: str = "#000"
892
893 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore
894
895 _dag: DAG | None = field(init=False, default=None)
896
    # Make this optional so the type matches the one defined in LoggingMixin
898 _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False)
899 _logger_name: str | None = None
900
901 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
902 __serialized_fields: ClassVar[frozenset[str] | None] = None
903
904 _comps: ClassVar[set[str]] = {
905 "task_id",
906 "dag_id",
907 "owner",
908 "email",
909 "email_on_retry",
910 "retry_delay",
911 "retry_exponential_backoff",
912 "max_retry_delay",
913 "start_date",
914 "end_date",
915 "depends_on_past",
916 "wait_for_downstream",
917 "priority_weight",
918 "execution_timeout",
919 "has_on_execute_callback",
920 "has_on_failure_callback",
921 "has_on_success_callback",
922 "has_on_retry_callback",
923 "has_on_skipped_callback",
924 "do_xcom_push",
925 "multiple_outputs",
926 "allow_nested_operators",
927 "executor",
928 }
929
930 # If True, the Rendered Template fields will be overwritten in DB after execution
931 # This is useful for Taskflow decorators that modify the template fields during execution like
932 # @task.bash decorator.
933 overwrite_rtif_after_execution: bool = False
934
935 # If True then the class constructor was called
936 __instantiated: bool = False
    # List of args as passed to `__init__()`, after apply_defaults() has updated them. Used to "recreate"
    # the task when mapping.
939 # Set via the metaclass
940 __init_kwargs: dict[str, Any] = field(init=False)
941
942 # Set to True before calling execute method
943 _lock_for_execution: bool = False
944
945 # Set to True for an operator instantiated by a mapped operator.
946 __from_mapped: bool = False
947
948 start_trigger_args: StartTriggerArgs | None = None
949 start_from_trigger: bool = False
950
951 # base list which includes all the attrs that don't need deep copy.
952 _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = (
953 "user_defined_macros",
954 "user_defined_filters",
955 "params",
956 )
957
958 # each operator should override this class attr for shallow copy attrs.
959 shallow_copy_attrs: Sequence[str] = ()
960
961 def __setattr__(self: BaseOperator, key: str, value: Any):
962 if converter := getattr(self, f"_convert_{key}", None):
963 value = converter(value)
964 super().__setattr__(key, value)
965 if self.__from_mapped or self._lock_for_execution:
966 return # Skip any custom behavior for validation and during execute.
967 if key in self.__init_kwargs:
968 self.__init_kwargs[key] = value
969 if self.__instantiated and key in self.template_fields:
970 # Resolve upstreams set by assigning an XComArg after initializing
971 # an operator, example:
972 # op = BashOperator()
973 # op.bash_command = "sleep 1"
974 self._set_xcomargs_dependency(key, value)
975
976 def __init__(
977 self,
978 *,
979 task_id: str,
980 owner: str = DEFAULT_OWNER,
981 email: str | Sequence[str] | None = None,
982 email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY,
983 email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE,
984 retries: int | None = DEFAULT_RETRIES,
985 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
986 retry_exponential_backoff: float = 0,
987 max_retry_delay: timedelta | float | None = None,
988 start_date: datetime | None = None,
989 end_date: datetime | None = None,
990 depends_on_past: bool = False,
991 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
992 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
993 wait_for_downstream: bool = False,
994 dag: DAG | None = None,
995 params: collections.abc.MutableMapping[str, Any] | None = None,
996 default_args: dict | None = None,
997 priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
998 weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
999 queue: str = DEFAULT_QUEUE,
1000 pool: str | None = None,
1001 pool_slots: int = DEFAULT_POOL_SLOTS,
1002 sla: timedelta | None = None,
1003 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
1004 on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1005 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
1006 on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1007 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
1008 on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1009 pre_execute: TaskPreExecuteHook | None = None,
1010 post_execute: TaskPostExecuteHook | None = None,
1011 trigger_rule: str = DEFAULT_TRIGGER_RULE,
1012 resources: dict[str, Any] | None = None,
1013 run_as_user: str | None = None,
1014 map_index_template: str | None = None,
1015 max_active_tis_per_dag: int | None = None,
1016 max_active_tis_per_dagrun: int | None = None,
1017 executor: str | None = None,
1018 executor_config: dict | None = None,
1019 do_xcom_push: bool = True,
1020 multiple_outputs: bool = False,
1021 inlets: Any | None = None,
1022 outlets: Any | None = None,
1023 task_group: TaskGroup | None = None,
1024 doc: str | None = None,
1025 doc_md: str | None = None,
1026 doc_json: str | None = None,
1027 doc_yaml: str | None = None,
1028 doc_rst: str | None = None,
1029 task_display_name: str | None = None,
1030 logger_name: str | None = None,
1031 allow_nested_operators: bool = True,
1032 **kwargs: Any,
1033 ):
1034 # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any
1035
1036 # Only apply task_group prefix if this operator was not created from a mapped operator
1037 # Mapped operators already have the prefix applied during their creation
1038 if task_group and not self.__from_mapped:
1039 self.task_id = task_group.child_id(task_id)
1040 task_group.add(self)
1041 else:
1042 self.task_id = task_id
1043
1044 super().__init__()
1045 self.task_group = task_group
1046
1047 kwargs.pop("_airflow_mapped_validation_only", None)
1048 if kwargs:
1049 raise TypeError(
1050 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
1051 f"Invalid arguments were:\n**kwargs: {redact(kwargs)}",
1052 )
1053 validate_key(self.task_id)
1054
1055 self.owner = owner
1056 self.email = email
1057 self.email_on_retry = email_on_retry
1058 self.email_on_failure = email_on_failure
1059
1060 if email is not None:
1061 warnings.warn(
1062 "Setting email on a task is deprecated; please migrate to SmtpNotifier.",
1063 RemovedInAirflow4Warning,
1064 stacklevel=2,
1065 )
1066 if email and email_on_retry is not None:
1067 warnings.warn(
1068 "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.",
1069 RemovedInAirflow4Warning,
1070 stacklevel=2,
1071 )
1072 if email and email_on_failure is not None:
1073 warnings.warn(
1074 "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.",
1075 RemovedInAirflow4Warning,
1076 stacklevel=2,
1077 )
1078
1079 if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
1080 raise ValueError(
1081 f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
1082 )
1083 self.execution_timeout = execution_timeout
1084
1085 self.on_execute_callback = _collect_from_input(on_execute_callback)
1086 self.on_failure_callback = _collect_from_input(on_failure_callback)
1087 self.on_success_callback = _collect_from_input(on_success_callback)
1088 self.on_retry_callback = _collect_from_input(on_retry_callback)
1089 self.on_skipped_callback = _collect_from_input(on_skipped_callback)
1090 self._pre_execute_hook = pre_execute
1091 self._post_execute_hook = post_execute
1092
1093 self.start_date = timezone.convert_to_utc(start_date)
1094 self.end_date = timezone.convert_to_utc(end_date)
1095 self.executor = executor
1096 self.executor_config = executor_config or {}
1097 self.run_as_user = run_as_user
1098 # TODO:
1099 # self.retries = parse_retries(retries)
1100 self.retries = retries
1101 self.queue = queue
1102 self.pool = DEFAULT_POOL_NAME if pool is None else pool
1103 self.pool_slots = pool_slots
1104 if self.pool_slots < 1:
1105 dag_str = f" in dag {dag.dag_id}" if dag else ""
1106 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
1107 if sla is not None:
1108 warnings.warn(
1109 "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1",
1110 stacklevel=2,
1111 )
1112
1113 try:
1114 TriggerRule(trigger_rule)
1115 except ValueError:
1116 raise ValueError(
                f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]} "
                f"for task '{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
1119 )
1120
1121 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
1122
1123 self.depends_on_past: bool = depends_on_past
1124 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
1125 self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
1126 self.wait_for_downstream: bool = wait_for_downstream
1127 if wait_for_downstream:
1128 self.depends_on_past = True
1129
1130 # Converted by setattr
1131 self.retry_delay = retry_delay # type: ignore[assignment]
1132 self.retry_exponential_backoff = retry_exponential_backoff
1133 if max_retry_delay is not None:
1134 self.max_retry_delay = max_retry_delay
1135
1136 self.resources = resources
1137
1138 self.params = ParamsDict(params)
1139
1140 self.priority_weight = priority_weight
1141 self.weight_rule = weight_rule
1142
1143 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
1144 self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
1145 self.do_xcom_push: bool = do_xcom_push
1146 self.map_index_template: str | None = map_index_template
1147 self.multiple_outputs: bool = multiple_outputs
1148
1149 self.doc_md = doc_md
1150 self.doc_json = doc_json
1151 self.doc_yaml = doc_yaml
1152 self.doc_rst = doc_rst
1153 self.doc = doc
1154
1155 self._task_display_name = task_display_name
1156
1157 self.allow_nested_operators = allow_nested_operators
1158
1159 self._logger_name = logger_name
1160
1161 # Lineage
1162 self.inlets = _collect_from_input(inlets)
1163 self.outlets = _collect_from_input(outlets)
1164
1165 if isinstance(self.template_fields, str):
1166 warnings.warn(
1167 f"The `template_fields` value for {self.task_type} is a string "
                "but should be a list or tuple of strings. Wrapping it in a list for execution. "
1169 f"Please update {self.task_type} accordingly.",
1170 UserWarning,
1171 stacklevel=2,
1172 )
1173 self.template_fields = [self.template_fields]
1174
1175 self.is_setup = False
1176 self.is_teardown = False
1177
1178 if SetupTeardownContext.active:
1179 SetupTeardownContext.update_context_map(self)
1180
        # We set self.dag right at the end as `_convert__dag` calls `dag.add_task` for us, and we need all the
1182 # other properties to be set at that point
1183 if dag is not None:
1184 self.dag = dag
1185
1186 validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
1187
1188 # Ensure priority_weight is within the valid range
1189 self.priority_weight = db_safe_priority(self.priority_weight)
1190
1191 def __eq__(self, other):
1192 if type(self) is type(other):
1193 # Use getattr() instead of __dict__ as __dict__ doesn't return
1194 # correct values for properties.
1195 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
1196 return False
1197
1198 def __ne__(self, other):
1199 return not self == other
1200
1201 def __hash__(self):
1202 hash_components = [type(self)]
1203 for component in self._comps:
1204 val = getattr(self, component, None)
1205 try:
1206 hash(val)
1207 hash_components.append(val)
1208 except TypeError:
1209 hash_components.append(repr(val))
1210 return hash(tuple(hash_components))
1211
1212 # /Composing Operators ---------------------------------------------
1213
1214 def __gt__(self, other):
1215 """
1216 Return [Operator] > [Outlet].
1217
1218 If other is an attr annotated object it is set as an outlet of this Operator.
1219 """
1220 if not isinstance(other, Iterable):
1221 other = [other]
1222
1223 for obj in other:
1224 if not attrs.has(obj):
1225 raise TypeError(f"Left hand side ({obj}) is not an outlet")
1226 self.add_outlets(other)
1227
1228 return self
1229
1230 def __lt__(self, other):
1231 """
1232 Return [Inlet] > [Operator] or [Operator] < [Inlet].
1233
1234 If other is an attr annotated object it is set as an inlet to this operator.
1235 """
1236 if not isinstance(other, Iterable):
1237 other = [other]
1238
1239 for obj in other:
1240 if not attrs.has(obj):
1241 raise TypeError(f"{obj} cannot be an inlet")
1242 self.add_inlets(other)
1243
1244 return self
1245
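    # Illustrative sketch of the two comparisons above, assuming an attrs-decorated asset
    # class such as ``airflow.sdk.Asset``:
    #
    #   task > Asset("s3://bucket/out.csv")  # registers the asset as an outlet of ``task``
    #   task < Asset("s3://bucket/in.csv")   # registers the asset as an inlet of ``task``
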
1246 def __deepcopy__(self, memo: dict[int, Any]):
1247 # Hack sorting double chained task lists by task_id to avoid hitting
1248 # max_depth on deepcopy operations.
1249 sys.setrecursionlimit(5000) # TODO fix this in a better way
1250
1251 cls = self.__class__
1252 result = cls.__new__(cls)
1253 memo[id(self)] = result
1254
1255 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs
1256
1257 for k, v_org in self.__dict__.items():
1258 if k not in shallow_copy:
1259 v = copy.deepcopy(v_org, memo)
1260 else:
1261 v = copy.copy(v_org)
1262
1263 # Bypass any setters, and set it on the object directly. This works since we are cloning ourself so
1264 # we know the type is already fine
1265 result.__dict__[k] = v
1266 return result
1267
1268 def __getstate__(self):
1269 state = dict(self.__dict__)
1270 if "_log" in state:
1271 del state["_log"]
1272
1273 return state
1274
1275 def __setstate__(self, state):
1276 self.__dict__ = state
1277
1278 def add_inlets(self, inlets: Iterable[Any]):
1279 """Set inlets to this operator."""
1280 self.inlets.extend(inlets)
1281
1282 def add_outlets(self, outlets: Iterable[Any]):
1283 """Define the outlets of this operator."""
1284 self.outlets.extend(outlets)
1285
1286 def get_dag(self) -> DAG | None:
1287 return self._dag
1288
1289 @property
1290 def dag(self) -> DAG:
1291 """Returns the Operator's Dag if set, otherwise raises an error."""
1292 if dag := self._dag:
1293 return dag
1294 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
1295
1296 @dag.setter
1297 def dag(self, dag: DAG | None) -> None:
1298 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok."""
1299 self._dag = dag
1300
1301 def _convert__dag(self, dag: DAG | None) -> DAG | None:
1302 # Called automatically by __setattr__ method
1303 from airflow.sdk.definitions.dag import DAG
1304
1305 if dag is None:
1306 return dag
1307
1308 if not isinstance(dag, DAG):
1309 raise TypeError(f"Expected dag; received {dag.__class__.__name__}")
1310 if self._dag is not None and self._dag is not dag:
1311 raise ValueError(f"The dag assigned to {self} can not be changed.")
1312
1313 if self.__from_mapped:
1314 pass # Don't add to dag -- the mapped task takes the place.
1315 elif dag.task_dict.get(self.task_id) is not self:
1316 dag.add_task(self)
1317 return dag
1318
1319 @staticmethod
1320 def _convert_retries(retries: Any) -> int | None:
1321 if retries is None:
1322 return 0
1323 if type(retries) == int: # noqa: E721
1324 return retries
1325 try:
1326 parsed_retries = int(retries)
1327 except (TypeError, ValueError):
1328 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}")
1329 return parsed_retries
1330
1331 @staticmethod
1332 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None:
1333 if value is None or isinstance(value, timedelta):
1334 return value
1335 return timedelta(seconds=value)
1336
1337 _convert_retry_delay = _convert_timedelta
1338 _convert_max_retry_delay = _convert_timedelta
1339
1340 @staticmethod
1341 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None:
1342 if resources is None:
1343 return None
1344
1345 from airflow.sdk.definitions.operator_resources import Resources
1346
1347 if isinstance(resources, Resources):
1348 return resources
1349
1350 return Resources(**resources)
1351
1352 def _convert_is_setup(self, value: bool) -> bool:
1353 """
1354 Setter for is_setup property.
1355
1356 :meta private:
1357 """
1358 if self.is_teardown and value:
1359 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.")
1360 return value
1361
1362 def _convert_is_teardown(self, value: bool) -> bool:
1363 if self.is_setup and value:
1364 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.")
1365 return value
1366
1367 @property
1368 def task_display_name(self) -> str:
1369 return self._task_display_name or self.task_id
1370
1371 def has_dag(self):
1372 """Return True if the Operator has been assigned to a Dag."""
1373 return self._dag is not None
1374
1375 def _set_xcomargs_dependencies(self) -> None:
1376 from airflow.sdk.definitions.xcom_arg import XComArg
1377
1378 for f in self.template_fields:
1379 arg = getattr(self, f, NOTSET)
1380 if arg is not NOTSET:
1381 XComArg.apply_upstream_relationship(self, arg)
1382
1383 def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None:
1384 """
1385 Resolve upstream dependencies of a task.
1386
1387 In this way passing an ``XComArg`` as value for a template field
1388 will result in creating upstream relation between two tasks.
1389
1390 **Example**: ::
1391
1392 with DAG(...):
1393 generate_content = GenerateContentOperator(task_id="generate_content")
1394 send_email = EmailOperator(..., html_content=generate_content.output)
1395
1396 # This is equivalent to
1397 with DAG(...):
1398 generate_content = GenerateContentOperator(task_id="generate_content")
1399 send_email = EmailOperator(
1400 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
1401 )
1402 generate_content >> send_email
1403
1404 """
1405 from airflow.sdk.definitions.xcom_arg import XComArg
1406
1407 if field not in self.template_fields:
1408 return
1409 XComArg.apply_upstream_relationship(self, newvalue)
1410
1411 def on_kill(self) -> None:
1412 """
1413 Override this method to clean up subprocesses when a task instance gets killed.
1414
1415 Any use of the threading, subprocess or multiprocessing module within an
1416 operator needs to be cleaned up, or it will leave ghost processes behind.
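
        A minimal sketch of an override that terminates a child process started in
        ``execute`` (``MyOperator`` and ``self._proc`` are illustrative, not part of BaseOperator):

        .. code-block:: python

            import subprocess


            class MyOperator(BaseOperator):
                def execute(self, context):
                    # Launch a child process and remember it so on_kill can reach it.
                    self._proc = subprocess.Popen(["sleep", "3600"])
                    self._proc.wait()

                def on_kill(self) -> None:
                    # Stop the subprocess launched in execute(), if it is still running.
                    if getattr(self, "_proc", None) is not None:
                        self._proc.terminate()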
1417 """
1418
1419 def __repr__(self):
1420 return f"<Task({self.task_type}): {self.task_id}>"
1421
1422 @property
1423 def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
1424 return self.__class__
1425
1426 @property
1427 def task_type(self) -> str:
1428 """@property: type of the task."""
1429 return self.__class__.__name__
1430
1431 @property
1432 def operator_name(self) -> str:
1433 """@property: use a more friendly display name for the operator, if set."""
1434 try:
1435 return self.custom_operator_name # type: ignore
1436 except AttributeError:
1437 return self.task_type
1438
1439 @property
1440 def roots(self) -> list[BaseOperator]:
1441 """Required by DAGNode."""
1442 return [self]
1443
1444 @property
1445 def leaves(self) -> list[BaseOperator]:
1446 """Required by DAGNode."""
1447 return [self]
1448
1449 @property
1450 def output(self) -> XComArg:
1451 """Returns reference to XCom pushed by current operator."""
1452 from airflow.sdk.definitions.xcom_arg import XComArg
1453
1454 return XComArg(operator=self)
1455
1456 @classmethod
1457 def get_serialized_fields(cls):
1458 """Stringified Dags and operators contain exactly these fields."""
1459 if not cls.__serialized_fields:
1460 from airflow.sdk.definitions._internal.contextmanager import DagContext
1461
            # Make sure the following "fake" task is not added to the currently
            # active dag in context; otherwise it would raise
            # `RuntimeError: dictionary changed size during iteration`
            # during the SerializedDAG.serialize_dag() call.
1466 DagContext.push(None)
1467 cls.__serialized_fields = frozenset(
1468 vars(BaseOperator(task_id="test")).keys()
1469 - {
1470 "upstream_task_ids",
1471 "default_args",
1472 "dag",
1473 "_dag",
1474 "label",
1475 "_BaseOperator__instantiated",
1476 "_BaseOperator__init_kwargs",
1477 "_BaseOperator__from_mapped",
1478 "on_failure_fail_dagrun",
1479 "task_group",
1480 "_task_type",
1481 "operator_extra_links",
1482 "on_execute_callback",
1483 "on_failure_callback",
1484 "on_success_callback",
1485 "on_retry_callback",
1486 "on_skipped_callback",
1487 }
1488 | { # Class level defaults, or `@property` need to be added to this list
1489 "start_date",
1490 "end_date",
1491 "task_type",
1492 "ui_color",
1493 "ui_fgcolor",
1494 "template_ext",
1495 "template_fields",
1496 "template_fields_renderers",
1497 "params",
1498 "is_setup",
1499 "is_teardown",
1500 "on_failure_fail_dagrun",
1501 "map_index_template",
1502 "start_trigger_args",
1503 "_needs_expansion",
1504 "start_from_trigger",
1505 "max_retry_delay",
1506 "has_on_execute_callback",
1507 "has_on_failure_callback",
1508 "has_on_success_callback",
1509 "has_on_retry_callback",
1510 "has_on_skipped_callback",
1511 }
1512 )
1513 DagContext.pop()
1514
1515 return cls.__serialized_fields
1516
1517 def prepare_for_execution(self) -> Self:
1518 """Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
1519 other = copy.copy(self)
1520 other._lock_for_execution = True
1521 return other
1522
1523 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
1524 """Serialize; required by DAGNode."""
1525 from airflow.serialization.enums import DagAttributeTypes
1526
1527 return DagAttributeTypes.OP, self.task_id
1528
1529 def unmap(self, resolve: None | Mapping[str, Any]) -> Self:
1530 """
1531 Get the "normal" operator from the current operator.
1532
1533 Since a BaseOperator is not mapped to begin with, this simply returns
1534 the original operator.
1535
1536 :meta private:
1537 """
1538 return self
1539
1540 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
1541 """
1542 Get the start_trigger_args value of the current abstract operator.
1543
1544 Since a BaseOperator is not mapped to begin with, this simply returns
1545 the original value of start_trigger_args.
1546
1547 :meta private:
1548 """
1549 return self.start_trigger_args
1550
1551 def render_template_fields(
1552 self,
1553 context: Context,
1554 jinja_env: jinja2.Environment | None = None,
1555 ) -> None:
1556 """
1557 Template all attributes listed in *self.template_fields*.
1558
1559 This mutates the attributes in-place and is irreversible.
1560
1561 :param context: Context dict with values to apply on content.
1562 :param jinja_env: Jinja's environment to use for rendering.
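
        A minimal sketch with a hypothetical operator that declares one templated field
        (``EchoOperator`` and the context values are illustrative, not part of the SDK):

        .. code-block:: python

            class EchoOperator(BaseOperator):
                template_fields = ("message",)

                def __init__(self, message: str, **kwargs):
                    super().__init__(**kwargs)
                    self.message = message

                def execute(self, context):
                    return self.message


            op = EchoOperator(task_id="echo", message="run date: {{ ds }}")
            op.render_template_fields(context={"ds": "2025-01-01"})
            # op.message is now rendered as "run date: 2025-01-01"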
1563 """
1564 if not jinja_env:
1565 jinja_env = self.get_template_env()
1566 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
1567
1568 def pre_execute(self, context: Any):
1569 """Execute right before self.execute() is called."""
1570
1571 def execute(self, context: Context) -> Any:
1572 """
        Override this method when deriving a new operator.
1574
1575 The main method to execute the task. Context is the same dictionary used
1576 as when rendering jinja templates.
1577
1578 Refer to get_template_context for more context.
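
        A minimal sketch of an override (``HelloOperator`` is illustrative):

        .. code-block:: python

            class HelloOperator(BaseOperator):
                def execute(self, context):
                    # The return value is pushed to XCom as "return_value"
                    # when do_xcom_push is left at its default.
                    return f"hello from {self.task_id}"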
1579 """
1580 raise NotImplementedError()
1581
1582 def post_execute(self, context: Any, result: Any = None):
1583 """
1584 Execute right after self.execute() is called.
1585
1586 It is passed the execution context and any results returned by the operator.
1587 """
1588
1589 def defer(
1590 self,
1591 *,
1592 trigger: BaseTrigger,
1593 method_name: str,
1594 kwargs: dict[str, Any] | None = None,
1595 timeout: timedelta | int | float | None = None,
1596 ) -> NoReturn:
1597 """
1598 Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
1599
        This is achieved by raising a special exception (``TaskDeferred``) that is caught by the main
        ``_execute_task`` wrapper. Triggers can either send execution back to the task or end the task
        instance directly. If the trigger will end the task instance itself, ``method_name`` should be
        None; otherwise, provide the name of the method to call when execution resumes in the task.
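
        A minimal sketch of a deferrable ``execute`` (the ``TimeDeltaTrigger`` import path is an
        assumption and may differ between Airflow versions/providers):

        .. code-block:: python

            from datetime import timedelta

            from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger


            class WaitOneHourOperator(BaseOperator):
                def execute(self, context):
                    # Suspend the task; the triggerer resumes it after one hour.
                    self.defer(
                        trigger=TimeDeltaTrigger(timedelta(hours=1)),
                        method_name="execute_complete",
                    )

                def execute_complete(self, context, event=None):
                    # Called when the trigger fires; finish the task here.
                    return None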
1605 """
1606 from airflow.sdk.exceptions import TaskDeferred
1607
1608 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
1609
1610 def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
1611 """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed."""
1612 from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout
1613
1614 if next_kwargs is None:
1615 next_kwargs = {}
1616 # __fail__ is a special signal value for next_method that indicates
1617 # this task was scheduled specifically to fail.
1618
1619 if next_method == TRIGGER_FAIL_REPR:
1620 next_kwargs = next_kwargs or {}
1621 traceback = next_kwargs.get("traceback")
1622 if traceback is not None:
1623 self.log.error("Trigger failed:\n%s", "\n".join(traceback))
1624 if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT:
1625 raise TaskDeferralTimeout(error)
1626 raise TaskDeferralError(error)
1627 # Grab the callable off the Operator/Task and add in any kwargs
1628 execute_callable = getattr(self, next_method)
1629 return execute_callable(context, **next_kwargs)
1630
1631 def dry_run(self) -> None:
1632 """Perform dry run for the operator - just render template fields."""
1633 self.log.info("Dry run")
1634 for f in self.template_fields:
1635 try:
1636 content = getattr(self, f)
1637 except AttributeError:
1638 raise AttributeError(
1639 f"{f!r} is configured as a template field "
1640 f"but {self.task_type} does not have this attribute."
1641 )
1642
1643 if content and isinstance(content, str):
1644 self.log.info("Rendering template for %s", f)
1645 self.log.info(content)
1646
1647 @property
1648 def has_on_execute_callback(self) -> bool:
1649 """Return True if the task has execute callbacks."""
1650 return bool(self.on_execute_callback)
1651
1652 @property
1653 def has_on_failure_callback(self) -> bool:
1654 """Return True if the task has failure callbacks."""
1655 return bool(self.on_failure_callback)
1656
1657 @property
1658 def has_on_success_callback(self) -> bool:
1659 """Return True if the task has success callbacks."""
1660 return bool(self.on_success_callback)
1661
1662 @property
1663 def has_on_retry_callback(self) -> bool:
1664 """Return True if the task has retry callbacks."""
1665 return bool(self.on_retry_callback)
1666
1667 @property
1668 def has_on_skipped_callback(self) -> bool:
1669 """Return True if the task has skipped callbacks."""
1670 return bool(self.on_skipped_callback)
1671
1672
1673def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
1674 r"""
1675 Given a number of tasks, builds a dependency chain.
1676
    This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
    or lists containing any mix of these types. If you want to chain between two lists, they must be of the
    same length.
1680
1681 Using classic operators/sensors:
1682
1683 .. code-block:: python
1684
1685 chain(t1, [t2, t3], [t4, t5], t6)
1686
1687 is equivalent to::
1688
          / -> t2 -> t4 \
        t1               -> t6
          \ -> t3 -> t5 /
1692
1693 .. code-block:: python
1694
1695 t1.set_downstream(t2)
1696 t1.set_downstream(t3)
1697 t2.set_downstream(t4)
1698 t3.set_downstream(t5)
1699 t4.set_downstream(t6)
1700 t5.set_downstream(t6)
1701
1702 Using task-decorated functions aka XComArgs:
1703
1704 .. code-block:: python
1705
1706 chain(x1(), [x2(), x3()], [x4(), x5()], x6())
1707
1708 is equivalent to::
1709
          / -> x2 -> x4 \
        x1               -> x6
          \ -> x3 -> x5 /
1713
1714 .. code-block:: python
1715
1716 x1 = x1()
1717 x2 = x2()
1718 x3 = x3()
1719 x4 = x4()
1720 x5 = x5()
1721 x6 = x6()
1722 x1.set_downstream(x2)
1723 x1.set_downstream(x3)
1724 x2.set_downstream(x4)
1725 x3.set_downstream(x5)
1726 x4.set_downstream(x6)
1727 x5.set_downstream(x6)
1728
1729 Using TaskGroups:
1730
1731 .. code-block:: python
1732
        chain(t1, task_group1, task_group2, t2)

    is equivalent to:

    .. code-block:: python

        t1.set_downstream(task_group1)
        task_group1.set_downstream(task_group2)
        task_group2.set_downstream(t2)
1738
1739
1740 It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:
1741
1742 .. code-block:: python
1743
1744 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
1745
1746 is equivalent to::
1747
1748 / "branch one" -> x1 \
1749 t1 -> task_group1 -> x3
1750 \ "branch two" -> x2 /
1751
1752 .. code-block:: python
1753
1754 x1 = x1()
1755 x2 = x2()
1756 x3 = x3()
1757 label1 = Label("branch one")
1758 label2 = Label("branch two")
1759 t1.set_downstream(label1)
1760 label1.set_downstream(x1)
        t1.set_downstream(label2)
1762 label2.set_downstream(x2)
1763 x1.set_downstream(task_group1)
1764 x2.set_downstream(task_group1)
1765 task_group1.set_downstream(x3)
1766
1767 # or
1768
1769 x1 = x1()
1770 x2 = x2()
1771 x3 = x3()
1772 t1.set_downstream(x1, edge_modifier=Label("branch one"))
1773 t1.set_downstream(x2, edge_modifier=Label("branch two"))
1774 x1.set_downstream(task_group1)
1775 x2.set_downstream(task_group1)
1776 task_group1.set_downstream(x3)
1777
1778
1779 :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
1780 """
1781 for up_task, down_task in zip(tasks, tasks[1:]):
1782 if isinstance(up_task, DependencyMixin):
1783 up_task.set_downstream(down_task)
1784 continue
1785 if isinstance(down_task, DependencyMixin):
1786 down_task.set_upstream(up_task)
1787 continue
1788 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
1789 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
1790 up_task_list = up_task
1791 down_task_list = down_task
1792 if len(up_task_list) != len(down_task_list):
1793 raise ValueError(
1794 f"Chain not supported for different length Iterable. "
1795 f"Got {len(up_task_list)} and {len(down_task_list)}."
1796 )
1797 for up_t, down_t in zip(up_task_list, down_task_list):
1798 up_t.set_downstream(down_t)
1799
1800
1801def cross_downstream(
1802 from_tasks: Sequence[DependencyMixin],
1803 to_tasks: DependencyMixin | Sequence[DependencyMixin],
1804):
1805 r"""
1806 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
1807
1808 Using classic operators/sensors:
1809
1810 .. code-block:: python
1811
1812 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
1813
1814 is equivalent to::
1815
        t1 ---> t4
           \ /
        t2 -X -> t5
           / \
        t3 ---> t6
1821
1822 .. code-block:: python
1823
1824 t1.set_downstream(t4)
1825 t1.set_downstream(t5)
1826 t1.set_downstream(t6)
1827 t2.set_downstream(t4)
1828 t2.set_downstream(t5)
1829 t2.set_downstream(t6)
1830 t3.set_downstream(t4)
1831 t3.set_downstream(t5)
1832 t3.set_downstream(t6)
1833
1834 Using task-decorated functions aka XComArgs:
1835
1836 .. code-block:: python
1837
1838 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
1839
1840 is equivalent to::
1841
        x1 ---> x4
           \ /
        x2 -X -> x5
           / \
        x3 ---> x6
1847
1848 .. code-block:: python
1849
1850 x1 = x1()
1851 x2 = x2()
1852 x3 = x3()
1853 x4 = x4()
1854 x5 = x5()
1855 x6 = x6()
1856 x1.set_downstream(x4)
1857 x1.set_downstream(x5)
1858 x1.set_downstream(x6)
1859 x2.set_downstream(x4)
1860 x2.set_downstream(x5)
1861 x2.set_downstream(x6)
1862 x3.set_downstream(x4)
1863 x3.set_downstream(x5)
1864 x3.set_downstream(x6)
1865
1866 It is also possible to mix between classic operator/sensor and XComArg tasks:
1867
1868 .. code-block:: python
1869
1870 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1871
1872 is equivalent to::
1873
        t1 ---> x1
           \ /
        x2 -X -> t2
           / \
        t3 ---> x3
1879
1880 .. code-block:: python
1881
1882 x1 = x1()
1883 x2 = x2()
1884 x3 = x3()
1885 t1.set_downstream(x1)
1886 t1.set_downstream(t2)
1887 t1.set_downstream(x3)
1888 x2.set_downstream(x1)
1889 x2.set_downstream(t2)
1890 x2.set_downstream(x3)
1891 t3.set_downstream(x1)
1892 t3.set_downstream(t2)
1893 t3.set_downstream(x3)
1894
1895 :param from_tasks: List of tasks or XComArgs to start from.
1896 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1897 """
1898 for task in from_tasks:
1899 task.set_downstream(to_tasks)
1900
1901
1902def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
1903 """
1904 Simplify task dependency definition.
1905
1906 E.g.: suppose you want precedence like so::
1907
            ╭─op2─╮ ╭─op4─╮
        op1─┤     ├─├─op5─┤─op7
            ╰─op3─╯ ╰─op6─╯
1911
1912 Then you can accomplish like so::
1913
1914 chain_linear(op1, [op2, op3], [op4, op5, op6], op7)
1915
1916 :param elements: a list of operators / lists of operators
1917 """
1918 if not elements:
1919 raise ValueError("No tasks provided; nothing to do.")
1920 prev_elem = None
1921 deps_set = False
1922 for curr_elem in elements:
1923 if isinstance(curr_elem, EdgeModifier):
1924 raise ValueError("Labels are not supported by chain_linear")
1925 if prev_elem is not None:
            for task in prev_elem:
                task >> curr_elem
                if not deps_set:
                    deps_set = True
1930 prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
1931 if not deps_set:
1932 raise ValueError("No dependencies were set. Did you forget to expand with `*`?")