1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import abc
21import collections.abc
22import contextlib
23import copy
24import inspect
25import sys
26import warnings
27from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
28from contextvars import ContextVar
29from dataclasses import dataclass, field
30from datetime import datetime, timedelta
31from enum import Enum
32from functools import total_ordering, wraps
33from types import FunctionType
34from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast
35
36import attrs
37
38from airflow.sdk import TriggerRule, timezone
39from airflow.sdk._shared.secrets_masker import redact
40from airflow.sdk.definitions._internal.abstractoperator import (
41 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
42 DEFAULT_OWNER,
43 DEFAULT_POOL_NAME,
44 DEFAULT_POOL_SLOTS,
45 DEFAULT_PRIORITY_WEIGHT,
46 DEFAULT_QUEUE,
47 DEFAULT_RETRIES,
48 DEFAULT_RETRY_DELAY,
49 DEFAULT_TASK_EXECUTION_TIMEOUT,
50 DEFAULT_TRIGGER_RULE,
51 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
52 DEFAULT_WEIGHT_RULE,
53 AbstractOperator,
54 DependencyMixin,
55 TaskStateChangeCallback,
56)
57from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack
58from airflow.sdk.definitions._internal.node import validate_key
59from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
60from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args
61from airflow.sdk.definitions.edges import EdgeModifier
62from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
63from airflow.sdk.definitions.param import ParamsDict
64from airflow.sdk.exceptions import RemovedInAirflow4Warning
65from airflow.task.priority_strategy import (
66 PriorityWeightStrategy,
67 airflow_priority_weight_strategies,
68 validate_and_load_priority_weight_strategy,
69)
70
71# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
72# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
73# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
74# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
75DB_SAFE_MINIMUM = -2147483648
76DB_SAFE_MAXIMUM = 2147483647
77
78
79def db_safe_priority(priority_weight: int) -> int:
80 """Convert priority weight to a safe value for the database."""
81 return max(DB_SAFE_MINIMUM, min(DB_SAFE_MAXIMUM, priority_weight))
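
# Illustrative only (not executed): out-of-range values are clamped, in-range values pass through.
# For example, db_safe_priority(2**40) == DB_SAFE_MAXIMUM, db_safe_priority(5) == 5,
# and db_safe_priority(-(2**40)) == DB_SAFE_MINIMUM.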
82
83
84C = TypeVar("C", bound=Callable)
85T = TypeVar("T", bound=FunctionType)
86
87if TYPE_CHECKING:
88 from types import ClassMethodDescriptorType
89
90 import jinja2
91 from typing_extensions import Self
92
93 from airflow.sdk.bases.operatorlink import BaseOperatorLink
94 from airflow.sdk.definitions.context import Context
95 from airflow.sdk.definitions.dag import DAG
96 from airflow.sdk.definitions.operator_resources import Resources
97 from airflow.sdk.definitions.taskgroup import TaskGroup
98 from airflow.sdk.definitions.xcom_arg import XComArg
99 from airflow.serialization.enums import DagAttributeTypes
100 from airflow.task.priority_strategy import PriorityWeightStrategy
101 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
102
103 TaskPreExecuteHook = Callable[[Context], None]
104 TaskPostExecuteHook = Callable[[Context, Any], None]
105
106__all__ = [
107 "BaseOperator",
108 "chain",
109 "chain_linear",
110 "cross_downstream",
111]
112
113
114class TriggerFailureReason(str, Enum):
115 """
116 Reasons for trigger failures.
117
118 Internal use only.
119
120 :meta private:
121 """
122
123 TRIGGER_TIMEOUT = "Trigger timeout"
124 TRIGGER_FAILURE = "Trigger failure"
125
126
127TRIGGER_FAIL_REPR = "__fail__"
128"""String value to represent trigger failure.
129
130Internal use only.
131
132:meta private:
133"""
134
135
136def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
137 if not dag:
138 return {}, ParamsDict()
139 dag_args = copy.copy(dag.default_args)
140 dag_params = copy.deepcopy(dag.params)
141 dag_params._fill_missing_param_source("dag")
142 if task_group:
143 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
144 raise TypeError("default_args must be a mapping")
145 dag_args.update(task_group.default_args)
146 return dag_args, dag_params
147
148
149def get_merged_defaults(
150 dag: DAG | None,
151 task_group: TaskGroup | None,
152 task_params: collections.abc.MutableMapping | None,
153 task_default_args: dict | None,
154) -> tuple[dict, ParamsDict]:
155 args, params = _get_parent_defaults(dag, task_group)
156 if task_params:
157 if not isinstance(task_params, collections.abc.Mapping):
158 raise TypeError(f"params must be a mapping, got {type(task_params)}")
159
160 task_params = ParamsDict(task_params)
161 task_params._fill_missing_param_source("task")
162 params.update(task_params)
163
164 if task_default_args:
165 if not isinstance(task_default_args, collections.abc.Mapping):
            raise TypeError(f"default_args must be a mapping, got {type(task_default_args)}")
167 args.update(task_default_args)
168 with contextlib.suppress(KeyError):
169 if params_from_default_args := ParamsDict(task_default_args["params"] or {}):
170 params_from_default_args._fill_missing_param_source("task")
171 params.update(params_from_default_args)
172
173 return args, params
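
# A minimal sketch of the merge semantics (illustrative only; the Dag below is hypothetical):
#
#     dag = DAG("example", default_args={"retries": 2, "owner": "data-team"}, params={"limit": 10})
#     args, params = get_merged_defaults(
#         dag=dag, task_group=None, task_params={"limit": 20}, task_default_args={"retries": 5}
#     )
#     # args == {"retries": 5, "owner": "data-team"}  -- task-level default_args win over Dag-level ones
#     # params["limit"] == 20                         -- task params override Dag params
#
# A "params" key inside the task-level default_args is likewise folded into the returned ParamsDict.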
174
175
176def parse_retries(retries: Any) -> int | None:
177 if retries is None:
178 return 0
179 if type(retries) == int: # noqa: E721
180 return retries
181 try:
182 parsed_retries = int(retries)
183 except (TypeError, ValueError):
184 raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}")
185 return parsed_retries
186
187
188def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta:
189 if isinstance(value, timedelta):
190 return value
191 return timedelta(seconds=value)
192
193
194def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
195 if resources is None:
196 return None
197 from airflow.sdk.definitions.operator_resources import Resources
198
199 return Resources(**resources)
200
201
202class _PartialDescriptor:
203 """A descriptor that guards against ``.partial`` being called on Task objects."""
204
205 class_method: ClassMethodDescriptorType | None = None
206
207 def __get__(
208 self, obj: BaseOperator, cls: type[BaseOperator] | None = None
209 ) -> Callable[..., OperatorPartial]:
210 # Call this "partial" so it looks nicer in stack traces.
211 def partial(**kwargs):
212 raise TypeError("partial can only be called on Operator classes, not Tasks themselves")
213
214 if obj is not None:
215 return partial
216 return self.class_method.__get__(cls, cls)
217
218
219OPERATOR_DEFAULTS: dict[str, Any] = {
220 "allow_nested_operators": True,
221 "depends_on_past": False,
222 "email_on_failure": True,
223 "email_on_retry": True,
224 "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
225 # "executor": DEFAULT_EXECUTOR,
226 "executor_config": {},
227 "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
228 "inlets": [],
229 "map_index_template": None,
230 "on_execute_callback": [],
231 "on_failure_callback": [],
232 "on_retry_callback": [],
233 "on_skipped_callback": [],
234 "on_success_callback": [],
235 "outlets": [],
236 "owner": DEFAULT_OWNER,
237 "pool_slots": DEFAULT_POOL_SLOTS,
238 "priority_weight": DEFAULT_PRIORITY_WEIGHT,
239 "queue": DEFAULT_QUEUE,
240 "retries": DEFAULT_RETRIES,
241 "retry_delay": DEFAULT_RETRY_DELAY,
242 "retry_exponential_backoff": 0,
243 "trigger_rule": DEFAULT_TRIGGER_RULE,
244 "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
245 "wait_for_downstream": False,
246 "weight_rule": DEFAULT_WEIGHT_RULE,
247}
248
249
250# This is what handles the actual mapping.
251
252if TYPE_CHECKING:
253
254 def partial(
255 operator_class: type[BaseOperator],
256 *,
257 task_id: str,
258 dag: DAG | None = None,
259 task_group: TaskGroup | None = None,
260 start_date: datetime = ...,
261 end_date: datetime = ...,
262 owner: str = ...,
263 email: None | str | Iterable[str] = ...,
264 params: collections.abc.MutableMapping | None = None,
265 resources: dict[str, Any] | None = ...,
266 trigger_rule: str = ...,
267 depends_on_past: bool = ...,
268 ignore_first_depends_on_past: bool = ...,
269 wait_for_past_depends_before_skipping: bool = ...,
270 wait_for_downstream: bool = ...,
271 retries: int | None = ...,
272 queue: str = ...,
273 pool: str = ...,
274 pool_slots: int = ...,
275 execution_timeout: timedelta | None = ...,
276 max_retry_delay: None | timedelta | float = ...,
277 retry_delay: timedelta | float = ...,
278 retry_exponential_backoff: float = ...,
279 priority_weight: int = ...,
280 weight_rule: str | PriorityWeightStrategy = ...,
281 sla: timedelta | None = ...,
282 map_index_template: str | None = ...,
283 max_active_tis_per_dag: int | None = ...,
284 max_active_tis_per_dagrun: int | None = ...,
285 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
286 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
287 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
288 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
289 on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
290 run_as_user: str | None = ...,
291 executor: str | None = ...,
292 executor_config: dict | None = ...,
293 inlets: Any | None = ...,
294 outlets: Any | None = ...,
295 doc: str | None = ...,
296 doc_md: str | None = ...,
297 doc_json: str | None = ...,
298 doc_yaml: str | None = ...,
299 doc_rst: str | None = ...,
300 task_display_name: str | None = ...,
301 logger_name: str | None = ...,
302 allow_nested_operators: bool = True,
303 **kwargs,
304 ) -> OperatorPartial: ...
305else:
306
307 def partial(
308 operator_class: type[BaseOperator],
309 *,
310 task_id: str,
311 dag: DAG | None = None,
312 task_group: TaskGroup | None = None,
313 params: collections.abc.MutableMapping | None = None,
314 **kwargs,
315 ):
316 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
317
318 validate_mapping_kwargs(operator_class, "partial", kwargs)
319
320 dag = dag or DagContext.get_current()
321 if dag:
322 task_group = task_group or TaskGroupContext.get_current(dag)
323 if task_group:
324 task_id = task_group.child_id(task_id)
325
326 # Merge Dag and task group level defaults into user-supplied values.
327 dag_default_args, partial_params = get_merged_defaults(
328 dag=dag,
329 task_group=task_group,
330 task_params=params,
331 task_default_args=kwargs.pop("default_args", None),
332 )
333
334 # Create partial_kwargs from args and kwargs
335 partial_kwargs: dict[str, Any] = {
336 "task_id": task_id,
337 "dag": dag,
338 "task_group": task_group,
339 **kwargs,
340 }
341
342 # Inject Dag-level default args into args provided to this function.
343 # Most of the default args will be retrieved during unmapping; here we
344 # only ensure base properties are correctly set for the scheduler.
345 partial_kwargs.update(
346 (k, v)
347 for k, v in dag_default_args.items()
348 if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names
349 )
350
351 # Fill fields not provided by the user with default values.
352 partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs)
353
354 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
355 if "task_concurrency" in kwargs: # Reject deprecated option.
356 raise TypeError("unexpected argument: task_concurrency")
357 if start_date := partial_kwargs.get("start_date", None):
358 partial_kwargs["start_date"] = timezone.convert_to_utc(start_date)
359 if end_date := partial_kwargs.get("end_date", None):
360 partial_kwargs["end_date"] = timezone.convert_to_utc(end_date)
361 if partial_kwargs["pool_slots"] < 1:
362 dag_str = ""
363 if dag:
364 dag_str = f" in dag {dag.dag_id}"
365 raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
366 if retries := partial_kwargs.get("retries"):
367 partial_kwargs["retries"] = BaseOperator._convert_retries(retries)
368 partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"])
369 partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay(
370 partial_kwargs.get("max_retry_delay", None)
371 )
372
        for k in ("execute", "failure", "success", "retry", "skipped"):
            attr = f"on_{k}_callback"
            partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr))
375
376 return OperatorPartial(
377 operator_class=operator_class,
378 kwargs=partial_kwargs,
379 params=partial_params,
380 )
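
# A minimal usage sketch (illustrative only): ``partial()`` is normally reached through
# ``Operator.partial(...)`` and paired with ``.expand(...)`` to create a mapped task, e.g.
#
#     BashOperator.partial(task_id="consume", retries=2).expand(
#         bash_command=["echo 1", "echo 2", "echo 3"]
#     )
#
# ``partial()`` captures the shared, per-task defaults; ``expand()`` supplies the
# per-map-index arguments.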
381
382
383class ExecutorSafeguard:
384 """
385 The ExecutorSafeguard decorator.
386
387 Checks if the execute method of an operator isn't manually called outside
388 the TaskInstance as we want to avoid bad mixing between decorated and
389 classic operators.
390 """
391
392 test_mode: ClassVar[bool] = False
393 tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel")
394 sentinel_value: ClassVar[object] = object()
395
396 @classmethod
397 def decorator(cls, func):
398 @wraps(func)
399 def wrapper(self, *args, **kwargs):
400 sentinel_key = f"{self.__class__.__name__}__sentinel"
401 sentinel = kwargs.pop(sentinel_key, None)
402
403 with contextlib.ExitStack() as stack:
404 if sentinel is cls.sentinel_value:
405 token = cls.tracker.set(self)
406 sentinel = self
407 stack.callback(cls.tracker.reset, token)
408 else:
409 # No sentinel passed in, maybe the subclass execute did have it passed?
410 sentinel = cls.tracker.get(None)
411
412 if not cls.test_mode and sentinel is not self:
413 message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!"
414 if not self.allow_nested_operators:
415 raise RuntimeError(message)
416 self.log.warning(message)
417
418 # Now that we've logged, set sentinel so that `super()` calls don't log again
419 token = cls.tracker.set(self)
420 stack.callback(cls.tracker.reset, token)
421
422 return func(self, *args, **kwargs)
423
424 return wrapper
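
# Behaviour sketch (illustrative only, assuming a standard BashOperator): with the safeguard
# applied, calling an operator's ``execute()`` by hand -- e.g. from inside another task -- logs a
# warning, or raises when ``allow_nested_operators=False``, because no sentinel was passed in by
# the Task Runner:
#
#     op = BashOperator(task_id="inner", bash_command="echo hi", allow_nested_operators=False)
#     op.execute(context)
#     # RuntimeError: BashOperator.execute cannot be called outside of the Task Runner!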
425
426
427if "airflow.configuration" in sys.modules:
428 # Don't try and import it if it's not already loaded
429 from airflow.sdk.configuration import conf
430
431 ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode")
432
433
434def _collect_from_input(value_or_values: None | C | Collection[C]) -> list[C]:
435 if not value_or_values:
436 return []
437 if isinstance(value_or_values, Collection):
438 return list(value_or_values)
439 return [value_or_values]
440
441
442class BaseOperatorMeta(abc.ABCMeta):
443 """Metaclass of BaseOperator."""
444
445 @classmethod
446 def _apply_defaults(cls, func: T) -> T:
447 """
448 Look for an argument named "default_args", and fill the unspecified arguments from it.
449
        Since Python isn't otherwise clear about which arguments are missing when
        calling a function, and this can be quite confusing with multi-level
        inheritance and argument defaults, this decorator also alerts with
        specific information about the missing arguments.
454 """
455 # Cache inspect.signature for the wrapper closure to avoid calling it
456 # at every decorated invocation. This is separate sig_cache created
457 # per decoration, i.e. each function decorated using apply_defaults will
458 # have a different sig_cache.
459 sig_cache = inspect.signature(func)
460 non_variadic_params = {
461 name: param
462 for (name, param) in sig_cache.parameters.items()
463 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
464 }
465 non_optional_args = {
466 name
467 for name, param in non_variadic_params.items()
468 if param.default == param.empty and name != "task_id"
469 }
470
471 fixup_decorator_warning_stack(func)
472
473 @wraps(func)
474 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
475 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
476
477 if args:
478 raise TypeError("Use keyword arguments when initializing operators")
479
480 instantiated_from_mapped = kwargs.pop(
481 "_airflow_from_mapped",
482 getattr(self, "_BaseOperator__from_mapped", False),
483 )
484
485 dag: DAG | None = kwargs.get("dag")
486 if dag is None:
487 dag = DagContext.get_current()
488 if dag is not None:
489 kwargs["dag"] = dag
490
491 task_group: TaskGroup | None = kwargs.get("task_group")
492 if dag and not task_group:
493 task_group = TaskGroupContext.get_current(dag)
494 if task_group is not None:
495 kwargs["task_group"] = task_group
496
497 default_args, merged_params = get_merged_defaults(
498 dag=dag,
499 task_group=task_group,
500 task_params=kwargs.pop("params", None),
501 task_default_args=kwargs.pop("default_args", None),
502 )
503
504 for arg in sig_cache.parameters:
505 if arg not in kwargs and arg in default_args:
506 kwargs[arg] = default_args[arg]
507
508 missing_args = non_optional_args.difference(kwargs)
509 if len(missing_args) == 1:
510 raise TypeError(f"missing keyword argument {missing_args.pop()!r}")
511 if missing_args:
512 display = ", ".join(repr(a) for a in sorted(missing_args))
513 raise TypeError(f"missing keyword arguments {display}")
514
515 if merged_params:
516 kwargs["params"] = merged_params
517
518 hook = getattr(self, "_hook_apply_defaults", None)
519 if hook:
520 args, kwargs = hook(**kwargs, default_args=default_args)
521 default_args = kwargs.pop("default_args", {})
522
523 if not hasattr(self, "_BaseOperator__init_kwargs"):
524 object.__setattr__(self, "_BaseOperator__init_kwargs", {})
525 object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped)
526
527 result = func(self, **kwargs, default_args=default_args)
528
529 # Store the args passed to init -- we need them to support task.map serialization!
530 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
531
532 # Set upstream task defined by XComArgs passed to template fields of the operator.
533 # BUT: only do this _ONCE_, not once for each class in the hierarchy
534 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
535 self._set_xcomargs_dependencies()
536 # Mark instance as instantiated so that future attr setting updates xcomarg-based deps.
537 object.__setattr__(self, "_BaseOperator__instantiated", True)
538
539 return result
540
541 apply_defaults.__non_optional_args = non_optional_args # type: ignore
542 apply_defaults.__param_names = set(non_variadic_params) # type: ignore
543
544 return cast("T", apply_defaults)
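
    # A minimal sketch of the effect (illustrative only; EmptyOperator stands in for any operator):
    #
    #     with DAG("example", default_args={"retries": 3, "queue": "ssd"}):
    #         t = EmptyOperator(task_id="t")
    #
    # the wrapper above fills ``retries=3`` and ``queue="ssd"`` into the operator's kwargs because
    # both names appear in the constructor signature and were not supplied explicitly.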
545
546 def __new__(cls, name, bases, namespace, **kwargs):
547 execute_method = namespace.get("execute")
548 if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
549 namespace["execute"] = ExecutorSafeguard.decorator(execute_method)
550 new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
551 with contextlib.suppress(KeyError):
552 # Update the partial descriptor with the class method, so it calls the actual function
553 # (but let subclasses override it if they need to)
554 partial_desc = vars(new_cls)["partial"]
555 if isinstance(partial_desc, _PartialDescriptor):
556 partial_desc.class_method = classmethod(partial)
557
558 # We patch `__init__` only if the class defines it.
559 first_superclass = new_cls.mro()[1]
560 if new_cls.__init__ is not first_superclass.__init__:
561 new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
562
563 return new_cls
564
565
566# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the
567# correct type. This is a temporary solution until we find a more sophisticated method for argument
568# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not
569# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
570# version that supports `get_type_hints` effectively or find a better approach, we can replace this
571# manual type-checking method.
572BASEOPERATOR_ARGS_EXPECTED_TYPES = {
573 "task_id": str,
574 "email": (str, Sequence),
575 "email_on_retry": bool,
576 "email_on_failure": bool,
577 "retries": int,
578 "retry_exponential_backoff": (int, float),
579 "depends_on_past": bool,
580 "ignore_first_depends_on_past": bool,
581 "wait_for_past_depends_before_skipping": bool,
582 "wait_for_downstream": bool,
583 "priority_weight": int,
584 "queue": str,
585 "pool": str,
586 "pool_slots": int,
587 "trigger_rule": str,
588 "run_as_user": str,
589 "task_concurrency": int,
590 "map_index_template": str,
591 "max_active_tis_per_dag": int,
592 "max_active_tis_per_dagrun": int,
593 "executor": str,
594 "do_xcom_push": bool,
595 "multiple_outputs": bool,
596 "doc": str,
597 "doc_md": str,
598 "doc_json": str,
599 "doc_yaml": str,
600 "doc_rst": str,
601 "task_display_name": str,
602 "logger_name": str,
603 "allow_nested_operators": bool,
604 "start_date": datetime,
605 "end_date": datetime,
606}
607
608
609# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in
610# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it.
611#
612# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you
613# get nowhere leave your record here for the next poor soul and what problems you ran in to.
614#
615# @ashb, 2024/10/14
616# - "Can't combine custom __setattr__ with on_setattr hooks"
# - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses
618@total_ordering
619@dataclass(repr=False)
620class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
621 r"""
622 Abstract base class for all operators.
623
624 Since operators create objects that become nodes in the Dag, BaseOperator
625 contains many recursive methods for Dag crawling behavior. To derive from
626 this class, you are expected to override the constructor and the 'execute'
627 method.
628
629 Operators derived from this class should perform or trigger certain tasks
    synchronously (wait for completion). Examples of operators include one that
    runs a Pig job (PigOperator), a sensor operator that waits for a partition
    to land in Hive (HiveSensorOperator), or one that moves data from Hive to
    MySQL (Hive2MySqlOperator). Instances of these operators (tasks) target
    specific operations, running specific scripts, functions or data transfers.
636
637 This class is abstract and shouldn't be instantiated. Instantiating a
638 class derived from this one results in the creation of a task object,
639 which ultimately becomes a node in Dag objects. Task dependencies should
640 be set by using the set_upstream and/or set_downstream methods.
641
642 :param task_id: a unique, meaningful id for the task
643 :param owner: the owner of the task. Using a meaningful description
644 (e.g. user/person/team/role name) to clarify ownership is recommended.
645 :param email: the 'to' email address(es) used in email alerts. This can be a
646 single email or multiple ones. Multiple addresses can be specified as a
647 comma or semicolon separated string or by passing a list of strings. (deprecated)
648 :param email_on_retry: Indicates whether email alerts should be sent when a
649 task is retried (deprecated)
650 :param email_on_failure: Indicates whether email alerts should be sent when
651 a task failed (deprecated)
652 :param retries: the number of retries that should be performed before
653 failing the task
654 :param retry_delay: delay between retries, can be set as ``timedelta`` or
655 ``float`` seconds, which will be converted into ``timedelta``,
656 the default is ``timedelta(seconds=300)``.
657 :param retry_exponential_backoff: multiplier for exponential backoff between retries.
658 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff
659 (delay doubles with each retry). For example, with retry_delay=4min and
660 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc.
661 :param max_retry_delay: maximum delay interval between retries, can be set as
662 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``.
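
        **Example**: an illustrative retry configuration combining these settings
        (the operator used here is only a placeholder)::

            BashOperator(
                task_id="flaky",
                bash_command="exit 1",
                retries=3,
                retry_delay=timedelta(minutes=2),
                retry_exponential_backoff=2.0,
                max_retry_delay=timedelta(minutes=30),
            )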
663 :param start_date: The ``start_date`` for the task, determines
664 the ``logical_date`` for the first task instance. The best practice
665 is to have the start_date rounded
666 to your Dag's ``schedule_interval``. Daily jobs have their start_date
667 some day at 00:00:00, hourly jobs have their start_date at 00:00
668 of a specific hour. Note that Airflow simply looks at the latest
669 ``logical_date`` and adds the ``schedule_interval`` to determine
670 the next ``logical_date``. It is also very important
671 to note that different tasks' dependencies
672 need to line up in time. If task A depends on task B and their
673 start_date are offset in a way that their logical_date don't line
674 up, A's dependencies will never be met. If you are looking to delay
675 a task, for example running a daily task at 2AM, look into the
676 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
677 dynamic ``start_date`` and recommend using fixed ones. Read the
678 FAQ entry about start_date for more information.
679 :param end_date: if specified, the scheduler won't go beyond this date
680 :param depends_on_past: when set to true, task instances will run
681 sequentially and only if the previous instance has succeeded or has been skipped.
682 The task instance for the start_date is allowed to run.
683 :param wait_for_past_depends_before_skipping: when set to true, if the task instance
        should be marked as skipped, and depends_on_past is true, the task instance will stay
        in the ``None`` state waiting for the task instance of the previous run
686 :param wait_for_downstream: when set to true, an instance of task
687 X will wait for tasks immediately downstream of the previous instance
688 of task X to finish successfully or be skipped before it runs. This is useful if the
689 different instances of a task X alter the same asset, and this asset
690 is used by tasks downstream of task X. Note that depends_on_past
691 is forced to True wherever wait_for_downstream is used. Also note that
692 only tasks *immediately* downstream of the previous task instance are waited
693 for; the statuses of any tasks further downstream are ignored.
694 :param dag: a reference to the dag the task is attached to (if any)
    :param priority_weight: priority weight of this task against other tasks.
        This allows the executor to trigger higher priority tasks before
        others when things get backed up. Set priority_weight to a higher
        number for more important tasks.
        As not all database engines support 64-bit integers, values are capped to the
        32-bit signed integer range, i.e. from -2,147,483,648 to 2,147,483,647.
701 :param weight_rule: weighting method used for the effective total
702 priority weight of the task. Options are:
703 ``{ downstream | upstream | absolute }`` default is ``downstream``
704 When set to ``downstream`` the effective weight of the task is the
705 aggregate sum of all downstream descendants. As a result, upstream
706 tasks will have higher weight and will be scheduled more aggressively
707 when using positive weight values. This is useful when you have
708 multiple dag run instances and desire to have all upstream tasks to
709 complete for all runs before each dag can continue processing
710 downstream tasks. When set to ``upstream`` the effective weight is the
711 aggregate sum of all upstream ancestors. This is the opposite where
712 downstream tasks have higher weight and will be scheduled more
713 aggressively when using positive weight values. This is useful when you
714 have multiple dag run instances and prefer to have each dag complete
715 before starting upstream tasks of other dags. When set to
716 ``absolute``, the effective weight is the exact ``priority_weight``
717 specified without additional weighting. You may want to do this when
718 you know exactly what priority weight each task should have.
        Additionally, when set to ``absolute``, there is a bonus effect of
        significantly speeding up the task creation process for very large
        Dags. Options can be set as a string or using the constants defined in
        the static class ``airflow.utils.WeightRule``.
        Irrespective of the weight rule, resulting priority values are capped to the
        32-bit signed integer range.
        |experimental|
        Since 2.9.0, Airflow allows defining a custom priority weight strategy by
        creating a subclass of
        ``airflow.task.priority_strategy.PriorityWeightStrategy``, registering it in a
        plugin, and then providing the class path or the class instance via the
        ``weight_rule`` parameter. The custom priority weight strategy will be
        used to calculate the effective total priority weight of the task instance.
    :param queue: which queue to target when running this job. Not
        all executors implement queue management; the CeleryExecutor
        does support targeting specific queues.
    :param pool: the slot pool this task should run in; slot pools are a
        way to limit concurrency for certain tasks
736 :param pool_slots: the number of pool slots this task should use (>= 1)
737 Values less than 1 are not allowed.
738 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a
739 new implementation in Airflow >=3.1.
740 :param execution_timeout: max time allowed for the execution of
741 this task instance, if it goes beyond it will raise and fail.
742 :param on_failure_callback: a function or list of functions to be called when a task instance
        of this task fails. A context dictionary is passed as a single
        parameter to this function. The context contains references to objects
        related to the task instance and is documented under the macros
        section of the API.
747 :param on_execute_callback: much like the ``on_failure_callback`` except
748 that it is executed right before the task is executed.
749 :param on_retry_callback: much like the ``on_failure_callback`` except
750 that it is executed when retries occur.
751 :param on_success_callback: much like the ``on_failure_callback`` except
752 that it is executed when the task succeeds.
753 :param on_skipped_callback: much like the ``on_failure_callback`` except
        that it is executed when the task is skipped; this callback is called only if
        ``AirflowSkipException`` is raised. It is explicitly NOT called if the task never starts
        executing because of a preceding branching decision in the Dag or a trigger rule that
        causes execution to be skipped, so that the task execution is never scheduled.
758 :param pre_execute: a function to be called immediately before task
759 execution, receiving a context dictionary; raising an exception will
760 prevent the task from being executed.
761 :param post_execute: a function to be called immediately after task
762 execution, receiving a context dictionary and task result; raising an
763 exception will prevent the task from succeeding.
764 :param trigger_rule: defines the rule by which dependencies are applied
765 for the task to get triggered. Options are:
766 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
767 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
768 default is ``all_success``. Options can be set as string or
769 using the constants defined in the static class
770 ``airflow.utils.TriggerRule``
771 :param resources: A map of resource parameter names (the argument names of the
772 Resources constructor) to their values.
773 :param run_as_user: unix username to impersonate while running the task
774 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent
775 runs across logical_dates.
776 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent
777 task instances per Dag run.
778 :param executor: Which executor to target when running this task. NOT YET SUPPORTED
779 :param executor_config: Additional task-level configuration parameters that are
780 interpreted by a specific executor. Parameters are namespaced by the name of
781 executor.
782
783 **Example**: to run this task in a specific docker container through
784 the KubernetesExecutor ::
785
786 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}})
787
788 :param do_xcom_push: if True, an XCom is pushed containing the Operator's
789 result
790 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each
791 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom.
792 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not
793 using a TaskGroup as a context manager.
794 :param doc: Add documentation or notes to your Task objects that is visible in
795 Task Instance details View in the Webserver
796 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects
797 that is visible in Task Instance details View in the Webserver
798 :param doc_rst: Add documentation (in RST format) or notes to your Task objects
799 that is visible in Task Instance details View in the Webserver
800 :param doc_json: Add documentation (in JSON format) or notes to your Task objects
801 that is visible in Task Instance details View in the Webserver
802 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects
803 that is visible in Task Instance details View in the Webserver
804 :param task_display_name: The display name of the task which appears on the UI.
805 :param logger_name: Name of the logger used by the Operator to emit logs.
806 If set to `None` (default), the logger name will fall back to
807 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have
808 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger).
809 :param allow_nested_operators: if True, when an operator is executed within another one a warning message
810 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested
811 within another one). In future releases of Airflow this parameter will be removed and an exception
812 will always be thrown when operators are nested within each other (default is True).
813
814 **Example**: example of a bad operator mixin usage::
815
816 @task(provide_context=True)
817 def say_hello_world(**context):
818 hello_world_task = BashOperator(
819 task_id="hello_world_task",
820 bash_command="python -c \"print('Hello, world!')\"",
821 dag=dag,
822 )
823 hello_world_task.execute(context)
824 """
825
826 task_id: str
827 owner: str = DEFAULT_OWNER
828 email: str | Sequence[str] | None = None
829 email_on_retry: bool = True
830 email_on_failure: bool = True
831 retries: int | None = DEFAULT_RETRIES
832 retry_delay: timedelta = DEFAULT_RETRY_DELAY
833 retry_exponential_backoff: float = 0
834 max_retry_delay: timedelta | float | None = None
835 start_date: datetime | None = None
836 end_date: datetime | None = None
837 depends_on_past: bool = False
838 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
839 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
840 wait_for_downstream: bool = False
841
842 # At execution_time this becomes a normal dict
843 params: ParamsDict | dict = field(default_factory=ParamsDict)
844 default_args: dict | None = None
845 priority_weight: int = DEFAULT_PRIORITY_WEIGHT
846 weight_rule: PriorityWeightStrategy = field(
847 default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE]
848 )
849 queue: str = DEFAULT_QUEUE
850 pool: str = DEFAULT_POOL_NAME
851 pool_slots: int = DEFAULT_POOL_SLOTS
852 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
853 on_execute_callback: Sequence[TaskStateChangeCallback] = ()
854 on_failure_callback: Sequence[TaskStateChangeCallback] = ()
855 on_success_callback: Sequence[TaskStateChangeCallback] = ()
856 on_retry_callback: Sequence[TaskStateChangeCallback] = ()
857 on_skipped_callback: Sequence[TaskStateChangeCallback] = ()
858 _pre_execute_hook: TaskPreExecuteHook | None = None
859 _post_execute_hook: TaskPostExecuteHook | None = None
860 trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE
861 resources: dict[str, Any] | None = None
862 run_as_user: str | None = None
863 task_concurrency: int | None = None
864 map_index_template: str | None = None
865 max_active_tis_per_dag: int | None = None
866 max_active_tis_per_dagrun: int | None = None
867 executor: str | None = None
868 executor_config: dict | None = None
869 do_xcom_push: bool = True
870 multiple_outputs: bool = False
871 inlets: list[Any] = field(default_factory=list)
872 outlets: list[Any] = field(default_factory=list)
873 task_group: TaskGroup | None = None
874 doc: str | None = None
875 doc_md: str | None = None
876 doc_json: str | None = None
877 doc_yaml: str | None = None
878 doc_rst: str | None = None
879 _task_display_name: str | None = None
880 logger_name: str | None = None
881 allow_nested_operators: bool = True
882
883 is_setup: bool = False
884 is_teardown: bool = False
885
886 # TODO: Task-SDK: Make these ClassVar[]?
887 template_fields: Collection[str] = ()
888 template_ext: Sequence[str] = ()
889
890 template_fields_renderers: ClassVar[dict[str, str]] = {}
891
892 operator_extra_links: Collection[BaseOperatorLink] = ()
893
894 # Defines the color in the UI
895 ui_color: str = "#fff"
896 ui_fgcolor: str = "#000"
897
898 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore
899
900 _dag: DAG | None = field(init=False, default=None)
901
    # Make this optional so the type matches the one defined in LoggingMixin
903 _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False)
904 _logger_name: str | None = None
905
906 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
907 __serialized_fields: ClassVar[frozenset[str] | None] = None
908
909 _comps: ClassVar[set[str]] = {
910 "task_id",
911 "dag_id",
912 "owner",
913 "email",
914 "email_on_retry",
915 "retry_delay",
916 "retry_exponential_backoff",
917 "max_retry_delay",
918 "start_date",
919 "end_date",
920 "depends_on_past",
921 "wait_for_downstream",
922 "priority_weight",
923 "execution_timeout",
924 "has_on_execute_callback",
925 "has_on_failure_callback",
926 "has_on_success_callback",
927 "has_on_retry_callback",
928 "has_on_skipped_callback",
929 "do_xcom_push",
930 "multiple_outputs",
931 "allow_nested_operators",
932 "executor",
933 }
934
935 # If True, the Rendered Template fields will be overwritten in DB after execution
936 # This is useful for Taskflow decorators that modify the template fields during execution like
937 # @task.bash decorator.
938 overwrite_rtif_after_execution: bool = False
939
940 # If True then the class constructor was called
941 __instantiated: bool = False
    # List of args as passed to `__init__()`, after `apply_defaults()` has updated them. Used to
    # "recreate" the task when mapping.
944 # Set via the metaclass
945 __init_kwargs: dict[str, Any] = field(init=False)
946
947 # Set to True before calling execute method
948 _lock_for_execution: bool = False
949
950 # Set to True for an operator instantiated by a mapped operator.
951 __from_mapped: bool = False
952
953 start_trigger_args: StartTriggerArgs | None = None
954 start_from_trigger: bool = False
955
956 # base list which includes all the attrs that don't need deep copy.
957 _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = (
958 "user_defined_macros",
959 "user_defined_filters",
960 "params",
961 )
962
963 # each operator should override this class attr for shallow copy attrs.
964 shallow_copy_attrs: Sequence[str] = ()
965
966 def __setattr__(self: BaseOperator, key: str, value: Any):
967 if converter := getattr(self, f"_convert_{key}", None):
968 value = converter(value)
969 super().__setattr__(key, value)
970 if self.__from_mapped or self._lock_for_execution:
971 return # Skip any custom behavior for validation and during execute.
972 if key in self.__init_kwargs:
973 self.__init_kwargs[key] = value
974 if self.__instantiated and key in self.template_fields:
975 # Resolve upstreams set by assigning an XComArg after initializing
976 # an operator, example:
977 # op = BashOperator()
978 # op.bash_command = "sleep 1"
979 self._set_xcomargs_dependency(key, value)
980
981 def __init__(
982 self,
983 *,
984 task_id: str,
985 owner: str = DEFAULT_OWNER,
986 email: str | Sequence[str] | None = None,
987 email_on_retry: bool = True,
988 email_on_failure: bool = True,
989 retries: int | None = DEFAULT_RETRIES,
990 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
991 retry_exponential_backoff: float = 0,
992 max_retry_delay: timedelta | float | None = None,
993 start_date: datetime | None = None,
994 end_date: datetime | None = None,
995 depends_on_past: bool = False,
996 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
997 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
998 wait_for_downstream: bool = False,
999 dag: DAG | None = None,
1000 params: collections.abc.MutableMapping[str, Any] | None = None,
1001 default_args: dict | None = None,
1002 priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
1003 weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
1004 queue: str = DEFAULT_QUEUE,
1005 pool: str | None = None,
1006 pool_slots: int = DEFAULT_POOL_SLOTS,
1007 sla: timedelta | None = None,
1008 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
1009 on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1010 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
1011 on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1012 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
1013 on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1014 pre_execute: TaskPreExecuteHook | None = None,
1015 post_execute: TaskPostExecuteHook | None = None,
1016 trigger_rule: str = DEFAULT_TRIGGER_RULE,
1017 resources: dict[str, Any] | None = None,
1018 run_as_user: str | None = None,
1019 map_index_template: str | None = None,
1020 max_active_tis_per_dag: int | None = None,
1021 max_active_tis_per_dagrun: int | None = None,
1022 executor: str | None = None,
1023 executor_config: dict | None = None,
1024 do_xcom_push: bool = True,
1025 multiple_outputs: bool = False,
1026 inlets: Any | None = None,
1027 outlets: Any | None = None,
1028 task_group: TaskGroup | None = None,
1029 doc: str | None = None,
1030 doc_md: str | None = None,
1031 doc_json: str | None = None,
1032 doc_yaml: str | None = None,
1033 doc_rst: str | None = None,
1034 task_display_name: str | None = None,
1035 logger_name: str | None = None,
1036 allow_nested_operators: bool = True,
1037 **kwargs: Any,
1038 ):
1039 # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any
1040
1041 # Only apply task_group prefix if this operator was not created from a mapped operator
1042 # Mapped operators already have the prefix applied during their creation
1043 if task_group and not self.__from_mapped:
1044 self.task_id = task_group.child_id(task_id)
1045 task_group.add(self)
1046 else:
1047 self.task_id = task_id
1048
1049 super().__init__()
1050 self.task_group = task_group
1051
1052 kwargs.pop("_airflow_mapped_validation_only", None)
1053 if kwargs:
1054 raise TypeError(
1055 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
1056 f"Invalid arguments were:\n**kwargs: {redact(kwargs)}",
1057 )
1058 validate_key(self.task_id)
1059
1060 self.owner = owner
1061 self.email = email
1062 self.email_on_retry = email_on_retry
1063 self.email_on_failure = email_on_failure
1064
1065 if email is not None:
1066 warnings.warn(
1067 "Setting email on a task is deprecated; please migrate to SmtpNotifier.",
1068 RemovedInAirflow4Warning,
1069 stacklevel=2,
1070 )
1071 if email and email_on_retry is not None:
1072 warnings.warn(
1073 "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.",
1074 RemovedInAirflow4Warning,
1075 stacklevel=2,
1076 )
1077 if email and email_on_failure is not None:
1078 warnings.warn(
1079 "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.",
1080 RemovedInAirflow4Warning,
1081 stacklevel=2,
1082 )
1083
1084 if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
1085 raise ValueError(
1086 f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
1087 )
1088 self.execution_timeout = execution_timeout
1089
1090 self.on_execute_callback = _collect_from_input(on_execute_callback)
1091 self.on_failure_callback = _collect_from_input(on_failure_callback)
1092 self.on_success_callback = _collect_from_input(on_success_callback)
1093 self.on_retry_callback = _collect_from_input(on_retry_callback)
1094 self.on_skipped_callback = _collect_from_input(on_skipped_callback)
1095 self._pre_execute_hook = pre_execute
1096 self._post_execute_hook = post_execute
1097
1098 self.start_date = timezone.convert_to_utc(start_date)
1099 self.end_date = timezone.convert_to_utc(end_date)
1100 self.executor = executor
1101 self.executor_config = executor_config or {}
1102 self.run_as_user = run_as_user
1103 # TODO:
1104 # self.retries = parse_retries(retries)
1105 self.retries = retries
1106 self.queue = queue
1107 self.pool = DEFAULT_POOL_NAME if pool is None else pool
1108 self.pool_slots = pool_slots
1109 if self.pool_slots < 1:
1110 dag_str = f" in dag {dag.dag_id}" if dag else ""
1111 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
1112 if sla is not None:
1113 warnings.warn(
1114 "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1",
1115 stacklevel=2,
1116 )
1117
1118 try:
1119 TriggerRule(trigger_rule)
1120 except ValueError:
            raise ValueError(
                f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]} "
                f"for task '{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
            )
1125
1126 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
1127
1128 self.depends_on_past: bool = depends_on_past
1129 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
1130 self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
1131 self.wait_for_downstream: bool = wait_for_downstream
1132 if wait_for_downstream:
1133 self.depends_on_past = True
1134
1135 # Converted by setattr
1136 self.retry_delay = retry_delay # type: ignore[assignment]
1137 self.retry_exponential_backoff = retry_exponential_backoff
1138 if max_retry_delay is not None:
1139 self.max_retry_delay = max_retry_delay
1140
1141 self.resources = resources
1142
1143 self.params = ParamsDict(params)
1144
1145 self.priority_weight = priority_weight
1146 self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
1147
1148 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
1149 self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
1150 self.do_xcom_push: bool = do_xcom_push
1151 self.map_index_template: str | None = map_index_template
1152 self.multiple_outputs: bool = multiple_outputs
1153
1154 self.doc_md = doc_md
1155 self.doc_json = doc_json
1156 self.doc_yaml = doc_yaml
1157 self.doc_rst = doc_rst
1158 self.doc = doc
1159
1160 self._task_display_name = task_display_name
1161
1162 self.allow_nested_operators = allow_nested_operators
1163
1164 self._logger_name = logger_name
1165
1166 # Lineage
1167 self.inlets = _collect_from_input(inlets)
1168 self.outlets = _collect_from_input(outlets)
1169
1170 if isinstance(self.template_fields, str):
1171 warnings.warn(
1172 f"The `template_fields` value for {self.task_type} is a string "
1173 "but should be a list or tuple of string. Wrapping it in a list for execution. "
1174 f"Please update {self.task_type} accordingly.",
1175 UserWarning,
1176 stacklevel=2,
1177 )
1178 self.template_fields = [self.template_fields]
1179
1180 self.is_setup = False
1181 self.is_teardown = False
1182
1183 if SetupTeardownContext.active:
1184 SetupTeardownContext.update_context_map(self)
1185
1186 # We set self.dag right at the end as `_convert_dag` calls `dag.add_task` for us, and we need all the
1187 # other properties to be set at that point
1188 if dag is not None:
1189 self.dag = dag
1190
1191 validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
1192
1193 # Ensure priority_weight is within the valid range
1194 self.priority_weight = db_safe_priority(self.priority_weight)
1195
1196 def __eq__(self, other):
1197 if type(self) is type(other):
1198 # Use getattr() instead of __dict__ as __dict__ doesn't return
1199 # correct values for properties.
1200 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
1201 return False
1202
1203 def __ne__(self, other):
1204 return not self == other
1205
1206 def __hash__(self):
1207 hash_components = [type(self)]
1208 for component in self._comps:
1209 val = getattr(self, component, None)
1210 try:
1211 hash(val)
1212 hash_components.append(val)
1213 except TypeError:
1214 hash_components.append(repr(val))
1215 return hash(tuple(hash_components))
1216
1217 # /Composing Operators ---------------------------------------------
1218
1219 def __gt__(self, other):
1220 """
1221 Return [Operator] > [Outlet].
1222
1223 If other is an attr annotated object it is set as an outlet of this Operator.
1224 """
1225 if not isinstance(other, Iterable):
1226 other = [other]
1227
1228 for obj in other:
1229 if not attrs.has(obj):
1230 raise TypeError(f"Left hand side ({obj}) is not an outlet")
1231 self.add_outlets(other)
1232
1233 return self
1234
1235 def __lt__(self, other):
1236 """
1237 Return [Inlet] > [Operator] or [Operator] < [Inlet].
1238
1239 If other is an attr annotated object it is set as an inlet to this operator.
1240 """
1241 if not isinstance(other, Iterable):
1242 other = [other]
1243
1244 for obj in other:
1245 if not attrs.has(obj):
1246 raise TypeError(f"{obj} cannot be an inlet")
1247 self.add_inlets(other)
1248
1249 return self
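
    # A minimal usage sketch (illustrative only; ``Asset`` stands in for any attrs-decorated
    # lineage object):
    #
    #     from airflow.sdk import Asset
    #
    #     task > Asset("s3://bucket/out.csv")   # registers the asset as an outlet of ``task``
    #     Asset("s3://bucket/in.csv") > task    # registers the asset as an inlet of ``task``
    #
    # which is equivalent to calling ``task.add_outlets([...])`` / ``task.add_inlets([...])``.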
1250
1251 def __deepcopy__(self, memo: dict[int, Any]):
1252 # Hack sorting double chained task lists by task_id to avoid hitting
1253 # max_depth on deepcopy operations.
1254 sys.setrecursionlimit(5000) # TODO fix this in a better way
1255
1256 cls = self.__class__
1257 result = cls.__new__(cls)
1258 memo[id(self)] = result
1259
1260 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs
1261
1262 for k, v_org in self.__dict__.items():
1263 if k not in shallow_copy:
1264 v = copy.deepcopy(v_org, memo)
1265 else:
1266 v = copy.copy(v_org)
1267
1268 # Bypass any setters, and set it on the object directly. This works since we are cloning ourself so
1269 # we know the type is already fine
1270 result.__dict__[k] = v
1271 return result
1272
1273 def __getstate__(self):
1274 state = dict(self.__dict__)
1275 if "_log" in state:
1276 del state["_log"]
1277
1278 return state
1279
1280 def __setstate__(self, state):
1281 self.__dict__ = state
1282
1283 def add_inlets(self, inlets: Iterable[Any]):
1284 """Set inlets to this operator."""
1285 self.inlets.extend(inlets)
1286
1287 def add_outlets(self, outlets: Iterable[Any]):
1288 """Define the outlets of this operator."""
1289 self.outlets.extend(outlets)
1290
1291 def get_dag(self) -> DAG | None:
1292 return self._dag
1293
1294 @property
1295 def dag(self) -> DAG:
1296 """Returns the Operator's Dag if set, otherwise raises an error."""
1297 if dag := self._dag:
1298 return dag
1299 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
1300
1301 @dag.setter
1302 def dag(self, dag: DAG | None) -> None:
1303 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok."""
1304 self._dag = dag
1305
1306 def _convert__dag(self, dag: DAG | None) -> DAG | None:
1307 # Called automatically by __setattr__ method
1308 from airflow.sdk.definitions.dag import DAG
1309
1310 if dag is None:
1311 return dag
1312
1313 if not isinstance(dag, DAG):
1314 raise TypeError(f"Expected dag; received {dag.__class__.__name__}")
1315 if self._dag is not None and self._dag is not dag:
1316 raise ValueError(f"The dag assigned to {self} can not be changed.")
1317
1318 if self.__from_mapped:
1319 pass # Don't add to dag -- the mapped task takes the place.
1320 elif dag.task_dict.get(self.task_id) is not self:
1321 dag.add_task(self)
1322 return dag
1323
1324 @staticmethod
1325 def _convert_retries(retries: Any) -> int | None:
1326 if retries is None:
1327 return 0
1328 if type(retries) == int: # noqa: E721
1329 return retries
1330 try:
1331 parsed_retries = int(retries)
1332 except (TypeError, ValueError):
1333 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}")
1334 return parsed_retries
1335
1336 @staticmethod
1337 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None:
1338 if value is None or isinstance(value, timedelta):
1339 return value
1340 return timedelta(seconds=value)
1341
1342 _convert_retry_delay = _convert_timedelta
1343 _convert_max_retry_delay = _convert_timedelta
1344
1345 @staticmethod
1346 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None:
1347 if resources is None:
1348 return None
1349
1350 from airflow.sdk.definitions.operator_resources import Resources
1351
1352 if isinstance(resources, Resources):
1353 return resources
1354
1355 return Resources(**resources)
1356
1357 def _convert_is_setup(self, value: bool) -> bool:
1358 """
1359 Setter for is_setup property.
1360
1361 :meta private:
1362 """
1363 if self.is_teardown and value:
1364 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.")
1365 return value
1366
1367 def _convert_is_teardown(self, value: bool) -> bool:
1368 if self.is_setup and value:
1369 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.")
1370 return value
1371
1372 @property
1373 def task_display_name(self) -> str:
1374 return self._task_display_name or self.task_id
1375
1376 def has_dag(self):
1377 """Return True if the Operator has been assigned to a Dag."""
1378 return self._dag is not None
1379
1380 def _set_xcomargs_dependencies(self) -> None:
1381 from airflow.sdk.definitions.xcom_arg import XComArg
1382
1383 for f in self.template_fields:
1384 arg = getattr(self, f, NOTSET)
1385 if arg is not NOTSET:
1386 XComArg.apply_upstream_relationship(self, arg)
1387
1388 def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None:
1389 """
1390 Resolve upstream dependencies of a task.
1391
1392 In this way passing an ``XComArg`` as value for a template field
1393 will result in creating upstream relation between two tasks.
1394
1395 **Example**: ::
1396
1397 with DAG(...):
1398 generate_content = GenerateContentOperator(task_id="generate_content")
1399 send_email = EmailOperator(..., html_content=generate_content.output)
1400
1401 # This is equivalent to
1402 with DAG(...):
1403 generate_content = GenerateContentOperator(task_id="generate_content")
1404 send_email = EmailOperator(
1405 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
1406 )
1407 generate_content >> send_email
1408
1409 """
1410 from airflow.sdk.definitions.xcom_arg import XComArg
1411
1412 if field not in self.template_fields:
1413 return
1414 XComArg.apply_upstream_relationship(self, newvalue)
1415
1416 def on_kill(self) -> None:
1417 """
1418 Override this method to clean up subprocesses when a task instance gets killed.
1419
1420 Any use of the threading, subprocess or multiprocessing module within an
1421 operator needs to be cleaned up, or it will leave ghost processes behind.
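
        A minimal sketch of an override (illustrative only; assumes the operator
        stored a ``self.sub_process`` handle during ``execute``):

        .. code-block:: python

            def on_kill(self) -> None:
                # Terminate the child process started in execute(), if any, so it
                # does not outlive the task instance.
                if getattr(self, "sub_process", None) is not None:
                    self.sub_process.terminate()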
1422 """
1423
1424 def __repr__(self):
1425 return f"<Task({self.task_type}): {self.task_id}>"
1426
1427 @property
1428 def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
1429 return self.__class__
1430
1431 @property
1432 def task_type(self) -> str:
        """Return the task type, i.e. the name of the operator class."""
1434 return self.__class__.__name__
1435
1436 @property
1437 def operator_name(self) -> str:
        """Return a more friendly display name for the operator if one is set; otherwise the task type."""
1439 try:
1440 return self.custom_operator_name # type: ignore
1441 except AttributeError:
1442 return self.task_type
1443
1444 @property
1445 def roots(self) -> list[BaseOperator]:
1446 """Required by DAGNode."""
1447 return [self]
1448
1449 @property
1450 def leaves(self) -> list[BaseOperator]:
1451 """Required by DAGNode."""
1452 return [self]
1453
1454 @property
1455 def output(self) -> XComArg:
        """Return a reference to the XCom pushed by this operator."""
1457 from airflow.sdk.definitions.xcom_arg import XComArg
1458
1459 return XComArg(operator=self)
1460
1461 @classmethod
1462 def get_serialized_fields(cls):
1463 """Stringified Dags and operators contain exactly these fields."""
1464 if not cls.__serialized_fields:
1465 from airflow.sdk.definitions._internal.contextmanager import DagContext
1466
            # Make sure the following "fake" task is not added to the currently
            # active dag in context; otherwise it would raise
            # `RuntimeError: dictionary changed size during iteration`
            # from the SerializedDAG.serialize_dag() call.
1471 DagContext.push(None)
1472 cls.__serialized_fields = frozenset(
1473 vars(BaseOperator(task_id="test")).keys()
1474 - {
1475 "upstream_task_ids",
1476 "default_args",
1477 "dag",
1478 "_dag",
1479 "label",
1480 "_BaseOperator__instantiated",
1481 "_BaseOperator__init_kwargs",
1482 "_BaseOperator__from_mapped",
1483 "on_failure_fail_dagrun",
1484 "task_group",
1485 "_task_type",
1486 "operator_extra_links",
1487 "on_execute_callback",
1488 "on_failure_callback",
1489 "on_success_callback",
1490 "on_retry_callback",
1491 "on_skipped_callback",
1492 }
                | {  # Class-level defaults and `@property` values need to be added to this list
1494 "start_date",
1495 "end_date",
1496 "task_type",
1497 "ui_color",
1498 "ui_fgcolor",
1499 "template_ext",
1500 "template_fields",
1501 "template_fields_renderers",
1502 "params",
1503 "is_setup",
1504 "is_teardown",
1505 "on_failure_fail_dagrun",
1506 "map_index_template",
1507 "start_trigger_args",
1508 "_needs_expansion",
1509 "start_from_trigger",
1510 "max_retry_delay",
1511 "has_on_execute_callback",
1512 "has_on_failure_callback",
1513 "has_on_success_callback",
1514 "has_on_retry_callback",
1515 "has_on_skipped_callback",
1516 }
1517 )
1518 DagContext.pop()
1519
1520 return cls.__serialized_fields
1521
1522 def prepare_for_execution(self) -> Self:
1523 """Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
1524 other = copy.copy(self)
1525 other._lock_for_execution = True
1526 return other
1527
1528 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
1529 """Serialize; required by DAGNode."""
1530 from airflow.serialization.enums import DagAttributeTypes
1531
1532 return DagAttributeTypes.OP, self.task_id
1533
1534 def unmap(self, resolve: None | Mapping[str, Any]) -> Self:
1535 """
1536 Get the "normal" operator from the current operator.
1537
1538 Since a BaseOperator is not mapped to begin with, this simply returns
1539 the original operator.
1540
1541 :meta private:
1542 """
1543 return self
1544
1545 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
1546 """
1547 Get the start_trigger_args value of the current abstract operator.
1548
1549 Since a BaseOperator is not mapped to begin with, this simply returns
1550 the original value of start_trigger_args.
1551
1552 :meta private:
1553 """
1554 return self.start_trigger_args
1555
1556 def render_template_fields(
1557 self,
1558 context: Context,
1559 jinja_env: jinja2.Environment | None = None,
1560 ) -> None:
1561 """
1562 Template all attributes listed in *self.template_fields*.
1563
1564 This mutates the attributes in-place and is irreversible.
1565
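        A short sketch of how template fields are declared on a custom operator
        (illustrative only; ``MyOperator`` and ``my_param`` are hypothetical names):

        .. code-block:: python

            class MyOperator(BaseOperator):
                template_fields = ("my_param",)

                def __init__(self, my_param: str, **kwargs):
                    super().__init__(**kwargs)
                    # A value such as "{{ ds }}" is rendered in place before execute() runs.
                    self.my_param = my_param
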
1566 :param context: Context dict with values to apply on content.
1567 :param jinja_env: Jinja's environment to use for rendering.
1568 """
1569 if not jinja_env:
1570 jinja_env = self.get_template_env()
1571 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
1572
1573 def pre_execute(self, context: Any):
1574 """Execute right before self.execute() is called."""
1575
1576 def execute(self, context: Context) -> Any:
1577 """
        Override this method when deriving an operator.

        This is the main method to execute the task. The context is the same
        dictionary used when rendering jinja templates.
1582
1583 Refer to get_template_context for more context.
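
        A minimal sketch of a subclass (illustrative only; ``HelloOperator`` is a
        hypothetical example, and the return value is pushed to XCom by default):

        .. code-block:: python

            class HelloOperator(BaseOperator):
                def execute(self, context):
                    self.log.info("Hello from task %s", self.task_id)
                    return "hello"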
1584 """
1585 raise NotImplementedError()
1586
1587 def post_execute(self, context: Any, result: Any = None):
1588 """
1589 Execute right after self.execute() is called.
1590
1591 It is passed the execution context and any results returned by the operator.
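
        A minimal sketch of an override (illustrative only):

        .. code-block:: python

            def post_execute(self, context, result=None):
                self.log.info("Task %s finished with result: %s", self.task_id, result)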
1592 """
1593
1594 def defer(
1595 self,
1596 *,
1597 trigger: BaseTrigger,
1598 method_name: str,
1599 kwargs: dict[str, Any] | None = None,
1600 timeout: timedelta | int | float | None = None,
1601 ) -> NoReturn:
1602 """
1603 Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
1604
        This is achieved by raising a special exception (TaskDeferred)
        which is caught in the main _execute_task wrapper. Triggers can send execution back to the task
        or end the task instance directly. If the trigger ends the task instance itself, ``method_name``
        should be None; otherwise, provide the name of the method to call when execution resumes in the
        task.
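
        A minimal sketch of the deferral pattern (illustrative only;
        ``WaitTenMinutesOperator`` is a hypothetical operator and ``DateTimeTrigger``
        is assumed to come from ``airflow.triggers.temporal``):

        .. code-block:: python

            class WaitTenMinutesOperator(BaseOperator):
                def execute(self, context):
                    self.defer(
                        trigger=DateTimeTrigger(moment=timezone.utcnow() + timedelta(minutes=10)),
                        method_name="execute_complete",
                    )

                def execute_complete(self, context, event=None):
                    # Execution resumes here once the trigger fires.
                    return event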
1610 """
1611 from airflow.sdk.exceptions import TaskDeferred
1612
1613 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
1614
1615 def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
1616 """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed."""
1617 from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout
1618
1619 if next_kwargs is None:
1620 next_kwargs = {}
1621 # __fail__ is a special signal value for next_method that indicates
1622 # this task was scheduled specifically to fail.
1623
1624 if next_method == TRIGGER_FAIL_REPR:
1625 next_kwargs = next_kwargs or {}
1626 traceback = next_kwargs.get("traceback")
1627 if traceback is not None:
1628 self.log.error("Trigger failed:\n%s", "\n".join(traceback))
1629 if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT:
1630 raise TaskDeferralTimeout(error)
1631 raise TaskDeferralError(error)
1632 # Grab the callable off the Operator/Task and add in any kwargs
1633 execute_callable = getattr(self, next_method)
1634 return execute_callable(context, **next_kwargs)
1635
1636 def dry_run(self) -> None:
        """Perform a dry run of the operator: only render the template fields."""
1638 self.log.info("Dry run")
1639 for f in self.template_fields:
1640 try:
1641 content = getattr(self, f)
1642 except AttributeError:
1643 raise AttributeError(
1644 f"{f!r} is configured as a template field "
1645 f"but {self.task_type} does not have this attribute."
1646 )
1647
1648 if content and isinstance(content, str):
1649 self.log.info("Rendering template for %s", f)
1650 self.log.info(content)
1651
1652 @property
1653 def has_on_execute_callback(self) -> bool:
1654 """Return True if the task has execute callbacks."""
1655 return bool(self.on_execute_callback)
1656
1657 @property
1658 def has_on_failure_callback(self) -> bool:
1659 """Return True if the task has failure callbacks."""
1660 return bool(self.on_failure_callback)
1661
1662 @property
1663 def has_on_success_callback(self) -> bool:
1664 """Return True if the task has success callbacks."""
1665 return bool(self.on_success_callback)
1666
1667 @property
1668 def has_on_retry_callback(self) -> bool:
1669 """Return True if the task has retry callbacks."""
1670 return bool(self.on_retry_callback)
1671
1672 @property
1673 def has_on_skipped_callback(self) -> bool:
1674 """Return True if the task has skipped callbacks."""
1675 return bool(self.on_skipped_callback)
1676
1677
1678def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
1679 r"""
1680 Given a number of tasks, builds a dependency chain.
1681
    This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArgs,
    TaskGroups, or lists containing any mix of these types. If you want to chain two lists, they must
    have the same length.
1685
1686 Using classic operators/sensors:
1687
1688 .. code-block:: python
1689
1690 chain(t1, [t2, t3], [t4, t5], t6)
1691
1692 is equivalent to::
1693
1694 / -> t2 -> t4 \
1695 t1 -> t6
1696 \ -> t3 -> t5 /
1697
1698 .. code-block:: python
1699
1700 t1.set_downstream(t2)
1701 t1.set_downstream(t3)
1702 t2.set_downstream(t4)
1703 t3.set_downstream(t5)
1704 t4.set_downstream(t6)
1705 t5.set_downstream(t6)
1706
1707 Using task-decorated functions aka XComArgs:
1708
1709 .. code-block:: python
1710
1711 chain(x1(), [x2(), x3()], [x4(), x5()], x6())
1712
1713 is equivalent to::
1714
1715 / -> x2 -> x4 \
1716 x1 -> x6
1717 \ -> x3 -> x5 /
1718
1719 .. code-block:: python
1720
1721 x1 = x1()
1722 x2 = x2()
1723 x3 = x3()
1724 x4 = x4()
1725 x5 = x5()
1726 x6 = x6()
1727 x1.set_downstream(x2)
1728 x1.set_downstream(x3)
1729 x2.set_downstream(x4)
1730 x3.set_downstream(x5)
1731 x4.set_downstream(x6)
1732 x5.set_downstream(x6)
1733
1734 Using TaskGroups:
1735
1736 .. code-block:: python
1737
1738 chain(t1, task_group1, task_group2, t2)

    is equivalent to:

    .. code-block:: python

        t1.set_downstream(task_group1)
        task_group1.set_downstream(task_group2)
        task_group2.set_downstream(t2)
1743
1744
    It is also possible to mix classic operators/sensors, EdgeModifiers, XComArgs, and TaskGroups:
1746
1747 .. code-block:: python
1748
1749 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
1750
1751 is equivalent to::
1752
1753 / "branch one" -> x1 \
1754 t1 -> task_group1 -> x3
1755 \ "branch two" -> x2 /
1756
1757 .. code-block:: python
1758
1759 x1 = x1()
1760 x2 = x2()
1761 x3 = x3()
1762 label1 = Label("branch one")
1763 label2 = Label("branch two")
1764 t1.set_downstream(label1)
1765 label1.set_downstream(x1)
        t1.set_downstream(label2)
1767 label2.set_downstream(x2)
1768 x1.set_downstream(task_group1)
1769 x2.set_downstream(task_group1)
1770 task_group1.set_downstream(x3)
1771
1772 # or
1773
1774 x1 = x1()
1775 x2 = x2()
1776 x3 = x3()
1777 t1.set_downstream(x1, edge_modifier=Label("branch one"))
1778 t1.set_downstream(x2, edge_modifier=Label("branch two"))
1779 x1.set_downstream(task_group1)
1780 x2.set_downstream(task_group1)
1781 task_group1.set_downstream(x3)
1782
1783
    :param tasks: Individual tasks and/or lists of tasks, EdgeModifiers, XComArgs, or TaskGroups
        on which to set dependencies.
1785 """
1786 for up_task, down_task in zip(tasks, tasks[1:]):
1787 if isinstance(up_task, DependencyMixin):
1788 up_task.set_downstream(down_task)
1789 continue
1790 if isinstance(down_task, DependencyMixin):
1791 down_task.set_upstream(up_task)
1792 continue
1793 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
1794 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
1795 up_task_list = up_task
1796 down_task_list = down_task
1797 if len(up_task_list) != len(down_task_list):
1798 raise ValueError(
                f"Chain not supported between iterables of different lengths. "
1800 f"Got {len(up_task_list)} and {len(down_task_list)}."
1801 )
1802 for up_t, down_t in zip(up_task_list, down_task_list):
1803 up_t.set_downstream(down_t)
1804
1805
1806def cross_downstream(
1807 from_tasks: Sequence[DependencyMixin],
1808 to_tasks: DependencyMixin | Sequence[DependencyMixin],
1809):
1810 r"""
1811 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
1812
1813 Using classic operators/sensors:
1814
1815 .. code-block:: python
1816
1817 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
1818
1819 is equivalent to::
1820
1821 t1 ---> t4
1822 \ /
1823 t2 -X -> t5
1824 / \
1825 t3 ---> t6
1826
1827 .. code-block:: python
1828
1829 t1.set_downstream(t4)
1830 t1.set_downstream(t5)
1831 t1.set_downstream(t6)
1832 t2.set_downstream(t4)
1833 t2.set_downstream(t5)
1834 t2.set_downstream(t6)
1835 t3.set_downstream(t4)
1836 t3.set_downstream(t5)
1837 t3.set_downstream(t6)
1838
1839 Using task-decorated functions aka XComArgs:
1840
1841 .. code-block:: python
1842
1843 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
1844
1845 is equivalent to::
1846
1847 x1 ---> x4
1848 \ /
1849 x2 -X -> x5
1850 / \
1851 x3 ---> x6
1852
1853 .. code-block:: python
1854
1855 x1 = x1()
1856 x2 = x2()
1857 x3 = x3()
1858 x4 = x4()
1859 x5 = x5()
1860 x6 = x6()
1861 x1.set_downstream(x4)
1862 x1.set_downstream(x5)
1863 x1.set_downstream(x6)
1864 x2.set_downstream(x4)
1865 x2.set_downstream(x5)
1866 x2.set_downstream(x6)
1867 x3.set_downstream(x4)
1868 x3.set_downstream(x5)
1869 x3.set_downstream(x6)
1870
    It is also possible to mix classic operators/sensors and XComArg tasks:
1872
1873 .. code-block:: python
1874
1875 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1876
1877 is equivalent to::
1878
1879 t1 ---> x1
1880 \ /
1881 x2 -X -> t2
1882 / \
1883 t3 ---> x3
1884
1885 .. code-block:: python
1886
1887 x1 = x1()
1888 x2 = x2()
1889 x3 = x3()
1890 t1.set_downstream(x1)
1891 t1.set_downstream(t2)
1892 t1.set_downstream(x3)
1893 x2.set_downstream(x1)
1894 x2.set_downstream(t2)
1895 x2.set_downstream(x3)
1896 t3.set_downstream(x1)
1897 t3.set_downstream(t2)
1898 t3.set_downstream(x3)
1899
1900 :param from_tasks: List of tasks or XComArgs to start from.
1901 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1902 """
1903 for task in from_tasks:
1904 task.set_downstream(to_tasks)
1905
1906
1907def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
1908 """
1909 Simplify task dependency definition.
1910
    For example, suppose you want precedence like so::
1912
            ╭─op2─╮ ╭─op4─╮
        op1─┤     ├─┤ op5 ├─op7
            ╰─op3─╯ ╰─op6─╯
1916
    Then you can accomplish it like so::
1918
1919 chain_linear(op1, [op2, op3], [op4, op5, op6], op7)
1920
1921 :param elements: a list of operators / lists of operators
1922 """
1923 if not elements:
1924 raise ValueError("No tasks provided; nothing to do.")
1925 prev_elem = None
1926 deps_set = False
1927 for curr_elem in elements:
1928 if isinstance(curr_elem, EdgeModifier):
1929 raise ValueError("Labels are not supported by chain_linear")
1930 if prev_elem is not None:
1931 for task in prev_elem:
1932 task >> curr_elem
1933 if not deps_set:
1934 deps_set = True
1935 prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
1936 if not deps_set:
1937 raise ValueError("No dependencies were set. Did you forget to expand with `*`?")