1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import abc
21import collections.abc
22import contextlib
23import copy
24import inspect
25import sys
26import warnings
27from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
28from contextvars import ContextVar
29from dataclasses import dataclass, field
30from datetime import datetime, timedelta
31from enum import Enum
32from functools import total_ordering, wraps
33from types import FunctionType
34from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast
35
36import attrs
37
38from airflow.sdk import TriggerRule, timezone
39from airflow.sdk._shared.secrets_masker import redact
40from airflow.sdk.definitions._internal.abstractoperator import (
41 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
42 DEFAULT_OWNER,
43 DEFAULT_POOL_NAME,
44 DEFAULT_POOL_SLOTS,
45 DEFAULT_PRIORITY_WEIGHT,
46 DEFAULT_QUEUE,
47 DEFAULT_RETRIES,
48 DEFAULT_RETRY_DELAY,
49 DEFAULT_TASK_EXECUTION_TIMEOUT,
50 DEFAULT_TRIGGER_RULE,
51 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
52 DEFAULT_WEIGHT_RULE,
53 AbstractOperator,
54 DependencyMixin,
55 TaskStateChangeCallback,
56)
57from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack
58from airflow.sdk.definitions._internal.node import validate_key
59from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
60from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args
61from airflow.sdk.definitions.edges import EdgeModifier
62from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
63from airflow.sdk.definitions.param import ParamsDict
64from airflow.sdk.exceptions import RemovedInAirflow4Warning
65from airflow.task.priority_strategy import (
66 PriorityWeightStrategy,
67 airflow_priority_weight_strategies,
68 validate_and_load_priority_weight_strategy,
69)
70
71# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
72# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
73# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
74# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
75DB_SAFE_MINIMUM = -2147483648
76DB_SAFE_MAXIMUM = 2147483647
77
78
79def db_safe_priority(priority_weight: int) -> int:
80 """Convert priority weight to a safe value for the database."""
81 return max(DB_SAFE_MINIMUM, min(DB_SAFE_MAXIMUM, priority_weight))
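# Illustrative behavior of db_safe_priority (comments only, not executed):
#
#     db_safe_priority(10)         ->  10
#     db_safe_priority(10**12)     ->  2147483647   (clamped to DB_SAFE_MAXIMUM)
#     db_safe_priority(-(10**12))  ->  -2147483648  (clamped to DB_SAFE_MINIMUM)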
82
83
84C = TypeVar("C", bound=Callable)
85T = TypeVar("T", bound=FunctionType)
86
87if TYPE_CHECKING:
88 from types import ClassMethodDescriptorType
89
90 import jinja2
91 from typing_extensions import Self
92
93 from airflow.sdk.bases.operatorlink import BaseOperatorLink
94 from airflow.sdk.definitions.context import Context
95 from airflow.sdk.definitions.dag import DAG
96 from airflow.sdk.definitions.operator_resources import Resources
97 from airflow.sdk.definitions.taskgroup import TaskGroup
98 from airflow.sdk.definitions.xcom_arg import XComArg
99 from airflow.serialization.enums import DagAttributeTypes
100 from airflow.task.priority_strategy import PriorityWeightStrategy
101 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
102
103 TaskPreExecuteHook = Callable[[Context], None]
104 TaskPostExecuteHook = Callable[[Context, Any], None]
105
106__all__ = [
107 "BaseOperator",
108 "chain",
109 "chain_linear",
110 "cross_downstream",
111]
112
113
114class TriggerFailureReason(str, Enum):
115 """
116 Reasons for trigger failures.
117
118 Internal use only.
119
120 :meta private:
121 """
122
123 TRIGGER_TIMEOUT = "Trigger timeout"
124 TRIGGER_FAILURE = "Trigger failure"
125
126
127TRIGGER_FAIL_REPR = "__fail__"
128"""String value to represent trigger failure.
129
130Internal use only.
131
132:meta private:
133"""
134
135
136def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
137 if not dag:
138 return {}, ParamsDict()
139 dag_args = copy.copy(dag.default_args)
140 dag_params = copy.deepcopy(dag.params)
141 dag_params._fill_missing_param_source("dag")
142 if task_group:
143 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
144 raise TypeError("default_args must be a mapping")
145 dag_args.update(task_group.default_args)
146 return dag_args, dag_params
147
148
149def get_merged_defaults(
150 dag: DAG | None,
151 task_group: TaskGroup | None,
152 task_params: collections.abc.MutableMapping | None,
153 task_default_args: dict | None,
154) -> tuple[dict, ParamsDict]:
155 args, params = _get_parent_defaults(dag, task_group)
156 if task_params:
157 if not isinstance(task_params, collections.abc.Mapping):
158 raise TypeError(f"params must be a mapping, got {type(task_params)}")
159
160 task_params = ParamsDict(task_params)
161 task_params._fill_missing_param_source("task")
162 params.update(task_params)
163
164 if task_default_args:
165 if not isinstance(task_default_args, collections.abc.Mapping):
166 raise TypeError(f"default_args must be a mapping, got {type(task_params)}")
167 args.update(task_default_args)
168 with contextlib.suppress(KeyError):
169 if params_from_default_args := ParamsDict(task_default_args["params"] or {}):
170 params_from_default_args._fill_missing_param_source("task")
171 params.update(params_from_default_args)
172
173 return args, params
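# Rough illustration of how defaults merge (hypothetical ``dag`` object; comments only, not executed):
#
#     # given dag.default_args == {"owner": "data-team"} and no task group
#     args, params = get_merged_defaults(
#         dag=dag,
#         task_group=None,
#         task_params={"env": "staging"},     # layered on top of dag.params
#         task_default_args={"retries": 2},   # layered on top of dag.default_args
#     )
#     # args is now {"owner": "data-team", "retries": 2}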
174
175
176def parse_retries(retries: Any) -> int | None:
177 if retries is None:
178 return 0
179 if type(retries) == int: # noqa: E721
180 return retries
181 try:
182 parsed_retries = int(retries)
183 except (TypeError, ValueError):
184 raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}")
185 return parsed_retries
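# Illustrative behavior of parse_retries (comments only, not executed):
#
#     parse_retries(None)    ->  0      (no value means no retries)
#     parse_retries("3")     ->  3      (numeric strings are coerced to int)
#     parse_retries("lots")  ->  raises RuntimeError("'retries' type must be int, not str")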
186
187
188def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta:
189 if isinstance(value, timedelta):
190 return value
191 return timedelta(seconds=value)
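# Illustrative behavior of coerce_timedelta (comments only, not executed):
#
#     coerce_timedelta(300)                   ->  timedelta(seconds=300)
#     coerce_timedelta(timedelta(minutes=5))  ->  timedelta(minutes=5)   (passed through unchanged)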
192
193
194def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
195 if resources is None:
196 return None
197 from airflow.sdk.definitions.operator_resources import Resources
198
199 return Resources(**resources)
200
201
202class _PartialDescriptor:
203 """A descriptor that guards against ``.partial`` being called on Task objects."""
204
205 class_method: ClassMethodDescriptorType | None = None
206
207 def __get__(
208 self, obj: BaseOperator, cls: type[BaseOperator] | None = None
209 ) -> Callable[..., OperatorPartial]:
210 # Call this "partial" so it looks nicer in stack traces.
211 def partial(**kwargs):
212 raise TypeError("partial can only be called on Operator classes, not Tasks themselves")
213
214 if obj is not None:
215 return partial
216 return self.class_method.__get__(cls, cls)
217
218
219OPERATOR_DEFAULTS: dict[str, Any] = {
220 "allow_nested_operators": True,
221 "depends_on_past": False,
222 "email_on_failure": True,
223 "email_on_retry": True,
224 "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
225 # "executor": DEFAULT_EXECUTOR,
226 "executor_config": {},
227 "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
228 "inlets": [],
229 "map_index_template": None,
230 "on_execute_callback": [],
231 "on_failure_callback": [],
232 "on_retry_callback": [],
233 "on_skipped_callback": [],
234 "on_success_callback": [],
235 "outlets": [],
236 "owner": DEFAULT_OWNER,
237 "pool_slots": DEFAULT_POOL_SLOTS,
238 "priority_weight": DEFAULT_PRIORITY_WEIGHT,
239 "queue": DEFAULT_QUEUE,
240 "retries": DEFAULT_RETRIES,
241 "retry_delay": DEFAULT_RETRY_DELAY,
242 "retry_exponential_backoff": 0,
243 "trigger_rule": DEFAULT_TRIGGER_RULE,
244 "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
245 "wait_for_downstream": False,
246 "weight_rule": DEFAULT_WEIGHT_RULE,
247}
248
249
250# This is what handles the actual mapping.
251
252if TYPE_CHECKING:
253
254 def partial(
255 operator_class: type[BaseOperator],
256 *,
257 task_id: str,
258 dag: DAG | None = None,
259 task_group: TaskGroup | None = None,
260 start_date: datetime = ...,
261 end_date: datetime = ...,
262 owner: str = ...,
263 email: None | str | Iterable[str] = ...,
264 params: collections.abc.MutableMapping | None = None,
265 resources: dict[str, Any] | None = ...,
266 trigger_rule: str = ...,
267 depends_on_past: bool = ...,
268 ignore_first_depends_on_past: bool = ...,
269 wait_for_past_depends_before_skipping: bool = ...,
270 wait_for_downstream: bool = ...,
271 retries: int | None = ...,
272 queue: str = ...,
273 pool: str = ...,
274 pool_slots: int = ...,
275 execution_timeout: timedelta | None = ...,
276 max_retry_delay: None | timedelta | float = ...,
277 retry_delay: timedelta | float = ...,
278 retry_exponential_backoff: float = ...,
279 priority_weight: int = ...,
280 weight_rule: str | PriorityWeightStrategy = ...,
281 sla: timedelta | None = ...,
282 map_index_template: str | None = ...,
283 max_active_tis_per_dag: int | None = ...,
284 max_active_tis_per_dagrun: int | None = ...,
285 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
286 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
287 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
288 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
289 on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
290 run_as_user: str | None = ...,
291 executor: str | None = ...,
292 executor_config: dict | None = ...,
293 inlets: Any | None = ...,
294 outlets: Any | None = ...,
295 doc: str | None = ...,
296 doc_md: str | None = ...,
297 doc_json: str | None = ...,
298 doc_yaml: str | None = ...,
299 doc_rst: str | None = ...,
300 task_display_name: str | None = ...,
301 logger_name: str | None = ...,
302 allow_nested_operators: bool = True,
303 **kwargs,
304 ) -> OperatorPartial: ...
305else:
306
307 def partial(
308 operator_class: type[BaseOperator],
309 *,
310 task_id: str,
311 dag: DAG | None = None,
312 task_group: TaskGroup | None = None,
313 params: collections.abc.MutableMapping | None = None,
314 **kwargs,
315 ):
316 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
317
318 validate_mapping_kwargs(operator_class, "partial", kwargs)
319
320 dag = dag or DagContext.get_current()
321 if dag:
322 task_group = task_group or TaskGroupContext.get_current(dag)
323 if task_group:
324 task_id = task_group.child_id(task_id)
325
326 # Merge Dag and task group level defaults into user-supplied values.
327 dag_default_args, partial_params = get_merged_defaults(
328 dag=dag,
329 task_group=task_group,
330 task_params=params,
331 task_default_args=kwargs.pop("default_args", None),
332 )
333
334 # Create partial_kwargs from args and kwargs
335 partial_kwargs: dict[str, Any] = {
336 "task_id": task_id,
337 "dag": dag,
338 "task_group": task_group,
339 **kwargs,
340 }
341
342 # Inject Dag-level default args into args provided to this function.
343 # Most of the default args will be retrieved during unmapping; here we
344 # only ensure base properties are correctly set for the scheduler.
345 partial_kwargs.update(
346 (k, v)
347 for k, v in dag_default_args.items()
348 if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names
349 )
350
351 # Fill fields not provided by the user with default values.
352 partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs)
353
354 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
355 if "task_concurrency" in kwargs: # Reject deprecated option.
356 raise TypeError("unexpected argument: task_concurrency")
357 if start_date := partial_kwargs.get("start_date", None):
358 partial_kwargs["start_date"] = timezone.convert_to_utc(start_date)
359 if end_date := partial_kwargs.get("end_date", None):
360 partial_kwargs["end_date"] = timezone.convert_to_utc(end_date)
361 if partial_kwargs["pool_slots"] < 1:
362 dag_str = ""
363 if dag:
364 dag_str = f" in dag {dag.dag_id}"
365 raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
366 if retries := partial_kwargs.get("retries"):
367 partial_kwargs["retries"] = BaseOperator._convert_retries(retries)
368 partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"])
369 partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay(
370 partial_kwargs.get("max_retry_delay", None)
371 )
372
373 for k in ("execute", "failure", "success", "retry", "skipped"):
374 partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr := f"on_{k}_callback"))
375
376 return OperatorPartial(
377 operator_class=operator_class,
378 kwargs=partial_kwargs,
379 params=partial_params,
380 )
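# Typical way this function is reached (illustrative sketch; assumes BashOperator from the
# standard provider and is not executed here):
#
#     from airflow.sdk import DAG
#     from airflow.providers.standard.operators.bash import BashOperator
#
#     with DAG(dag_id="example"):
#         op_partial = BashOperator.partial(task_id="greet", retries=1)   # resolves to partial() above
#         mapped = op_partial.expand(bash_command=["echo a", "echo b"])   # one task instance per item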
381
382
383class ExecutorSafeguard:
384 """
385 The ExecutorSafeguard decorator.
386
387 Checks that the execute method of an operator is not called manually outside
388 the TaskInstance, as we want to avoid bad mixing between decorated and
389 classic operators.
390 """
391
392 test_mode: ClassVar[bool] = False
393 tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel")
394 sentinel_value: ClassVar[object] = object()
395
396 @classmethod
397 def decorator(cls, func):
398 @wraps(func)
399 def wrapper(self, *args, **kwargs):
400 sentinel_key = f"{self.__class__.__name__}__sentinel"
401 sentinel = kwargs.pop(sentinel_key, None)
402
403 with contextlib.ExitStack() as stack:
404 if sentinel is cls.sentinel_value:
405 token = cls.tracker.set(self)
406 sentinel = self
407 stack.callback(cls.tracker.reset, token)
408 else:
409 # No sentinel passed in; maybe the subclass's execute received it, so check the tracker.
410 sentinel = cls.tracker.get(None)
411
412 if not cls.test_mode and sentinel is not self:
413 message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!"
414 if not self.allow_nested_operators:
415 raise RuntimeError(message)
416 self.log.warning(message)
417
418 # Now that we've logged, set sentinel so that `super()` calls don't log again
419 token = cls.tracker.set(self)
420 stack.callback(cls.tracker.reset, token)
421
422 return func(self, *args, **kwargs)
423
424 return wrapper
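# What the safeguard catches (illustrative, not executed): calling execute() by hand from
# inside another running task, e.g.
#
#     BashOperator(task_id="nested", bash_command="echo hi").execute(context)
#
# logs a warning (or raises, when allow_nested_operators=False) because the Task Runner
# never injected the expected sentinel kwarg.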
425
426
427if "airflow.configuration" in sys.modules:
428 # Don't try to import it if it's not already loaded
429 from airflow.sdk.configuration import conf
430
431 ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode")
432
433
434def _collect_from_input(value_or_values: None | C | Collection[C]) -> list[C]:
435 if not value_or_values:
436 return []
437 if isinstance(value_or_values, Collection):
438 return list(value_or_values)
439 return [value_or_values]
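# Illustrative behavior of _collect_from_input (comments only, not executed):
#
#     _collect_from_input(None)              ->  []
#     _collect_from_input(my_callback)       ->  [my_callback]      (single callable wrapped in a list)
#     _collect_from_input([cb_one, cb_two])  ->  [cb_one, cb_two]   (collection copied into a new list)
#
# where my_callback, cb_one and cb_two are hypothetical callables.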
440
441
442class BaseOperatorMeta(abc.ABCMeta):
443 """Metaclass of BaseOperator."""
444
445 @classmethod
446 def _apply_defaults(cls, func: T) -> T:
447 """
448 Look for an argument named "default_args", and fill the unspecified arguments from it.
449
450 Since Python isn't always clear about which arguments are missing when
451 calling a function, and this can be quite confusing with multi-level
452 inheritance and argument defaults, this decorator also alerts with
453 specific information about the missing arguments.
454 """
455 # Cache inspect.signature for the wrapper closure to avoid calling it
456 # at every decorated invocation. This is a separate sig_cache created
457 # per decoration, i.e. each function decorated using apply_defaults will
458 # have a different sig_cache.
459 sig_cache = inspect.signature(func)
460 non_variadic_params = {
461 name: param
462 for (name, param) in sig_cache.parameters.items()
463 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
464 }
465 non_optional_args = {
466 name
467 for name, param in non_variadic_params.items()
468 if param.default == param.empty and name != "task_id"
469 }
470
471 fixup_decorator_warning_stack(func)
472
473 @wraps(func)
474 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
475 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
476
477 if args:
478 raise TypeError("Use keyword arguments when initializing operators")
479
480 instantiated_from_mapped = kwargs.pop(
481 "_airflow_from_mapped",
482 getattr(self, "_BaseOperator__from_mapped", False),
483 )
484
485 dag: DAG | None = kwargs.get("dag")
486 if dag is None:
487 dag = DagContext.get_current()
488 if dag is not None:
489 kwargs["dag"] = dag
490
491 task_group: TaskGroup | None = kwargs.get("task_group")
492 if dag and not task_group:
493 task_group = TaskGroupContext.get_current(dag)
494 if task_group is not None:
495 kwargs["task_group"] = task_group
496
497 default_args, merged_params = get_merged_defaults(
498 dag=dag,
499 task_group=task_group,
500 task_params=kwargs.pop("params", None),
501 task_default_args=kwargs.pop("default_args", None),
502 )
503
504 for arg in sig_cache.parameters:
505 if arg not in kwargs and arg in default_args:
506 kwargs[arg] = default_args[arg]
507
508 missing_args = non_optional_args.difference(kwargs)
509 if len(missing_args) == 1:
510 raise TypeError(f"missing keyword argument {missing_args.pop()!r}")
511 if missing_args:
512 display = ", ".join(repr(a) for a in sorted(missing_args))
513 raise TypeError(f"missing keyword arguments {display}")
514
515 if merged_params:
516 kwargs["params"] = merged_params
517
518 hook = getattr(self, "_hook_apply_defaults", None)
519 if hook:
520 args, kwargs = hook(**kwargs, default_args=default_args)
521 default_args = kwargs.pop("default_args", {})
522
523 if not hasattr(self, "_BaseOperator__init_kwargs"):
524 object.__setattr__(self, "_BaseOperator__init_kwargs", {})
525 object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped)
526
527 result = func(self, **kwargs, default_args=default_args)
528
529 # Store the args passed to init -- we need them to support task.map serialization!
530 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
531
532 # Set upstream task defined by XComArgs passed to template fields of the operator.
533 # BUT: only do this _ONCE_, not once for each class in the hierarchy
534 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
535 self._set_xcomargs_dependencies()
536 # Mark instance as instantiated so that future attr setting updates xcomarg-based deps.
537 object.__setattr__(self, "_BaseOperator__instantiated", True)
538
539 return result
540
541 apply_defaults.__non_optional_args = non_optional_args # type: ignore
542 apply_defaults.__param_names = set(non_variadic_params) # type: ignore
543
544 return cast("T", apply_defaults)
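# Rough illustration of what _apply_defaults gives you (hypothetical MyOperator; comments only,
# not executed):
#
#     with DAG(dag_id="example", default_args={"retries": 3, "owner": "data-team"}):
#         t = MyOperator(task_id="t1")   # retries and owner are filled in from default_args
#     # t.retries == 3 and t.owner == "data-team"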
545
546 def __new__(cls, name, bases, namespace, **kwargs):
547 execute_method = namespace.get("execute")
548 if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
549 namespace["execute"] = ExecutorSafeguard.decorator(execute_method)
550 new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
551 with contextlib.suppress(KeyError):
552 # Update the partial descriptor with the class method, so it calls the actual function
553 # (but let subclasses override it if they need to)
554 partial_desc = vars(new_cls)["partial"]
555 if isinstance(partial_desc, _PartialDescriptor):
556 partial_desc.class_method = classmethod(partial)
557
558 # We patch `__init__` only if the class defines it.
559 first_superclass = new_cls.mro()[1]
560 if new_cls.__init__ is not first_superclass.__init__:
561 new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
562
563 return new_cls
564
565
566# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the
567# correct type. This is a temporary solution until we find a more sophisticated method for argument
568# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not
569# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
570# version that supports `get_type_hints` effectively or find a better approach, we can replace this
571# manual type-checking method.
572BASEOPERATOR_ARGS_EXPECTED_TYPES = {
573 "task_id": str,
574 "email": (str, Sequence),
575 "email_on_retry": bool,
576 "email_on_failure": bool,
577 "retries": int,
578 "retry_exponential_backoff": (int, float),
579 "depends_on_past": bool,
580 "ignore_first_depends_on_past": bool,
581 "wait_for_past_depends_before_skipping": bool,
582 "wait_for_downstream": bool,
583 "priority_weight": int,
584 "queue": str,
585 "pool": str,
586 "pool_slots": int,
587 "trigger_rule": str,
588 "run_as_user": str,
589 "task_concurrency": int,
590 "map_index_template": str,
591 "max_active_tis_per_dag": int,
592 "max_active_tis_per_dagrun": int,
593 "executor": str,
594 "do_xcom_push": bool,
595 "multiple_outputs": bool,
596 "doc": str,
597 "doc_md": str,
598 "doc_json": str,
599 "doc_yaml": str,
600 "doc_rst": str,
601 "task_display_name": str,
602 "logger_name": str,
603 "allow_nested_operators": bool,
604 "start_date": datetime,
605 "end_date": datetime,
606}
607
608
609# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in
610# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it.
611#
612# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you
613# get nowhere leave your record here for the next poor soul and what problems you ran in to.
614#
615# @ashb, 2024/10/14
616# - "Can't combine custom __setattr__ with on_setattr hooks"
617 # - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses
618@total_ordering
619@dataclass(repr=False)
620class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
621 r"""
622 Abstract base class for all operators.
623
624 Since operators create objects that become nodes in the Dag, BaseOperator
625 contains many recursive methods for Dag crawling behavior. To derive from
626 this class, you are expected to override the constructor and the 'execute'
627 method.
628
629 Operators derived from this class should perform or trigger certain tasks
630 synchronously (wait for completion). Examples of operators include an
631 operator that runs a Pig job (PigOperator), a sensor operator that
632 waits for a partition to land in Hive (HiveSensorOperator), or one that
633 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these
634 operators (tasks) target specific operations, running specific scripts,
635 functions or data transfers.
636
637 This class is abstract and shouldn't be instantiated. Instantiating a
638 class derived from this one results in the creation of a task object,
639 which ultimately becomes a node in Dag objects. Task dependencies should
640 be set by using the set_upstream and/or set_downstream methods.
641
642 :param task_id: a unique, meaningful id for the task
643 :param owner: the owner of the task. Using a meaningful description
644 (e.g. user/person/team/role name) to clarify ownership is recommended.
645 :param email: the 'to' email address(es) used in email alerts. This can be a
646 single email or multiple ones. Multiple addresses can be specified as a
647 comma or semicolon separated string or by passing a list of strings. (deprecated)
648 :param email_on_retry: Indicates whether email alerts should be sent when a
649 task is retried (deprecated)
650 :param email_on_failure: Indicates whether email alerts should be sent when
651 a task failed (deprecated)
652 :param retries: the number of retries that should be performed before
653 failing the task
654 :param retry_delay: delay between retries; can be set as a ``timedelta`` or as
655 ``float`` seconds, which will be converted into a ``timedelta``.
656 The default is ``timedelta(seconds=300)``.
657 :param retry_exponential_backoff: multiplier for exponential backoff between retries.
658 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff
659 (delay doubles with each retry). For example, with retry_delay=4min and
660 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc.
661 :param max_retry_delay: maximum delay interval between retries, can be set as
662 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``.
663 :param start_date: The ``start_date`` for the task, determines
664 the ``logical_date`` for the first task instance. The best practice
665 is to have the start_date rounded
666 to your Dag's ``schedule_interval``. Daily jobs have their start_date
667 some day at 00:00:00, hourly jobs have their start_date at 00:00
668 of a specific hour. Note that Airflow simply looks at the latest
669 ``logical_date`` and adds the ``schedule_interval`` to determine
670 the next ``logical_date``. It is also very important
671 to note that different tasks' dependencies
672 need to line up in time. If task A depends on task B and their
673 start_date are offset in a way that their logical_date don't line
674 up, A's dependencies will never be met. If you are looking to delay
675 a task, for example running a daily task at 2AM, look into the
676 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
677 dynamic ``start_date`` and recommend using fixed ones. Read the
678 FAQ entry about start_date for more information.
679 :param end_date: if specified, the scheduler won't go beyond this date
680 :param depends_on_past: when set to true, task instances will run
681 sequentially and only if the previous instance has succeeded or has been skipped.
682 The task instance for the start_date is allowed to run.
683 :param wait_for_past_depends_before_skipping: when set to true, if the task instance
684 should be marked as skipped and depends_on_past is true, the task instance will stay in the
685 ``None`` state while waiting for the task of the previous run
686 :param wait_for_downstream: when set to true, an instance of task
687 X will wait for tasks immediately downstream of the previous instance
688 of task X to finish successfully or be skipped before it runs. This is useful if the
689 different instances of a task X alter the same asset, and this asset
690 is used by tasks downstream of task X. Note that depends_on_past
691 is forced to True wherever wait_for_downstream is used. Also note that
692 only tasks *immediately* downstream of the previous task instance are waited
693 for; the statuses of any tasks further downstream are ignored.
694 :param dag: a reference to the dag the task is attached to (if any)
695 :param priority_weight: priority weight of this task against other tasks.
696 This allows the executor to trigger higher priority tasks before
697 others when things get backed up. Set priority_weight to a higher
698 number for more important tasks.
699 As not all database engines support 64-bit integers, values are capped to 32-bit.
700 The valid range is from -2,147,483,648 to 2,147,483,647.
701 :param weight_rule: weighting method used for the effective total
702 priority weight of the task. Options are:
703 ``{ downstream | upstream | absolute }`` default is ``downstream``
704 When set to ``downstream`` the effective weight of the task is the
705 aggregate sum of all downstream descendants. As a result, upstream
706 tasks will have higher weight and will be scheduled more aggressively
707 when using positive weight values. This is useful when you have
708 multiple dag run instances and desire to have all upstream tasks to
709 complete for all runs before each dag can continue processing
710 downstream tasks. When set to ``upstream`` the effective weight is the
711 aggregate sum of all upstream ancestors. This is the opposite where
712 downstream tasks have higher weight and will be scheduled more
713 aggressively when using positive weight values. This is useful when you
714 have multiple dag run instances and prefer to have each dag complete
715 before starting upstream tasks of other dags. When set to
716 ``absolute``, the effective weight is the exact ``priority_weight``
717 specified without additional weighting. You may want to do this when
718 you know exactly what priority weight each task should have.
719 Additionally, setting it to ``absolute`` has the bonus effect of
720 significantly speeding up task creation for very large
721 Dags. Options can be set as a string or using the constants defined in
722 the static class ``airflow.utils.WeightRule``.
723 Irrespective of the weight rule, resulting priority values are capped to 32-bit.
724 |experimental|
725 Since 2.9.0, Airflow allows defining a custom priority weight strategy
726 by creating a subclass of
727 ``airflow.task.priority_strategy.PriorityWeightStrategy``, registering it
728 in a plugin, and then providing the class path or the class instance via the
729 ``weight_rule`` parameter. The custom priority weight strategy will be
730 used to calculate the effective total priority weight of the task instance.
731 :param queue: which queue to target when running this job. Not
732 all executors implement queue management; the CeleryExecutor, for example,
733 does support targeting specific queues.
734 :param pool: the slot pool this task should run in, slot pools are a
735 way to limit concurrency for certain tasks
736 :param pool_slots: the number of pool slots this task should use (>= 1)
737 Values less than 1 are not allowed.
738 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a
739 new implementation in Airflow >=3.1.
740 :param execution_timeout: max time allowed for the execution of
741 this task instance, if it goes beyond it will raise and fail.
742 :param on_failure_callback: a function or list of functions to be called when a task instance
743 of this task fails. A context dictionary is passed as a single
744 parameter to this function. The context contains references to objects
745 related to the task instance and is documented under the macros
746 section of the API.
747 :param on_execute_callback: much like the ``on_failure_callback`` except
748 that it is executed right before the task is executed.
749 :param on_retry_callback: much like the ``on_failure_callback`` except
750 that it is executed when retries occur.
751 :param on_success_callback: much like the ``on_failure_callback`` except
752 that it is executed when the task succeeds.
753 :param on_skipped_callback: much like the ``on_failure_callback`` except
754 that it is executed when a skip occurs; this callback is called only if ``AirflowSkipException`` is raised.
755 It is explicitly NOT called if the task never starts executing because of a preceding branching
756 decision in the Dag or a trigger rule which causes the task to be skipped, so that the task execution
757 is never scheduled.
758 :param pre_execute: a function to be called immediately before task
759 execution, receiving a context dictionary; raising an exception will
760 prevent the task from being executed.
761
762 |experimental|
763 :param post_execute: a function to be called immediately after task
764 execution, receiving a context dictionary and task result; raising an
765 exception will prevent the task from succeeding.
766
767 |experimental|
768 :param trigger_rule: defines the rule by which dependencies are applied
769 for the task to get triggered. Options are:
770 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
771 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
772 default is ``all_success``. Options can be set as string or
773 using the constants defined in the static class
774 ``airflow.utils.TriggerRule``
775 :param resources: A map of resource parameter names (the argument names of the
776 Resources constructor) to their values.
777 :param run_as_user: unix username to impersonate while running the task
778 :param max_active_tis_per_dag: When set, limits the number of concurrent task instances
779 of this task across logical_dates.
780 :param max_active_tis_per_dagrun: When set, limits the number of concurrent task instances
781 of this task per Dag run.
782 :param executor: Which executor to target when running this task. NOT YET SUPPORTED
783 :param executor_config: Additional task-level configuration parameters that are
784 interpreted by a specific executor. Parameters are namespaced by the name of
785 executor.
786
787 **Example**: to run this task in a specific docker container through
788 the KubernetesExecutor ::
789
790 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}})
791
792 :param do_xcom_push: if True, an XCom is pushed containing the Operator's
793 result
794 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each
795 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom.
796 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not
797 using a TaskGroup as a context manager.
798 :param doc: Add documentation or notes to your Task objects that is visible in
799 Task Instance details View in the Webserver
800 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects
801 that is visible in Task Instance details View in the Webserver
802 :param doc_rst: Add documentation (in RST format) or notes to your Task objects
803 that is visible in Task Instance details View in the Webserver
804 :param doc_json: Add documentation (in JSON format) or notes to your Task objects
805 that is visible in Task Instance details View in the Webserver
806 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects
807 that is visible in Task Instance details View in the Webserver
808 :param task_display_name: The display name of the task which appears on the UI.
809 :param logger_name: Name of the logger used by the Operator to emit logs.
810 If set to `None` (default), the logger name will fall back to
811 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have
812 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger).
813 :param allow_nested_operators: if True, when an operator is executed within another one a warning message
814 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested
815 within another one). In future releases of Airflow this parameter will be removed and an exception
816 will always be thrown when operators are nested within each other (default is True).
817
818 **Example**: example of a bad operator mixin usage::
819
820 @task(provide_context=True)
821 def say_hello_world(**context):
822 hello_world_task = BashOperator(
823 task_id="hello_world_task",
824 bash_command="python -c \"print('Hello, world!')\"",
825 dag=dag,
826 )
827 hello_world_task.execute(context)
828 """
829
830 task_id: str
831 owner: str = DEFAULT_OWNER
832 email: str | Sequence[str] | None = None
833 email_on_retry: bool = True
834 email_on_failure: bool = True
835 retries: int | None = DEFAULT_RETRIES
836 retry_delay: timedelta = DEFAULT_RETRY_DELAY
837 retry_exponential_backoff: float = 0
838 max_retry_delay: timedelta | float | None = None
839 start_date: datetime | None = None
840 end_date: datetime | None = None
841 depends_on_past: bool = False
842 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
843 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
844 wait_for_downstream: bool = False
845
846 # At execution time this becomes a normal dict
847 params: ParamsDict | dict = field(default_factory=ParamsDict)
848 default_args: dict | None = None
849 priority_weight: int = DEFAULT_PRIORITY_WEIGHT
850 weight_rule: PriorityWeightStrategy = field(
851 default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE]
852 )
853 queue: str = DEFAULT_QUEUE
854 pool: str = DEFAULT_POOL_NAME
855 pool_slots: int = DEFAULT_POOL_SLOTS
856 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
857 on_execute_callback: Sequence[TaskStateChangeCallback] = ()
858 on_failure_callback: Sequence[TaskStateChangeCallback] = ()
859 on_success_callback: Sequence[TaskStateChangeCallback] = ()
860 on_retry_callback: Sequence[TaskStateChangeCallback] = ()
861 on_skipped_callback: Sequence[TaskStateChangeCallback] = ()
862 _pre_execute_hook: TaskPreExecuteHook | None = None
863 _post_execute_hook: TaskPostExecuteHook | None = None
864 trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE
865 resources: dict[str, Any] | None = None
866 run_as_user: str | None = None
867 task_concurrency: int | None = None
868 map_index_template: str | None = None
869 max_active_tis_per_dag: int | None = None
870 max_active_tis_per_dagrun: int | None = None
871 executor: str | None = None
872 executor_config: dict | None = None
873 do_xcom_push: bool = True
874 multiple_outputs: bool = False
875 inlets: list[Any] = field(default_factory=list)
876 outlets: list[Any] = field(default_factory=list)
877 task_group: TaskGroup | None = None
878 doc: str | None = None
879 doc_md: str | None = None
880 doc_json: str | None = None
881 doc_yaml: str | None = None
882 doc_rst: str | None = None
883 _task_display_name: str | None = None
884 logger_name: str | None = None
885 allow_nested_operators: bool = True
886
887 is_setup: bool = False
888 is_teardown: bool = False
889
890 # TODO: Task-SDK: Make these ClassVar[]?
891 template_fields: Collection[str] = ()
892 template_ext: Sequence[str] = ()
893
894 template_fields_renderers: ClassVar[dict[str, str]] = {}
895
896 operator_extra_links: Collection[BaseOperatorLink] = ()
897
898 # Defines the color in the UI
899 ui_color: str = "#fff"
900 ui_fgcolor: str = "#000"
901
902 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore
903
904 _dag: DAG | None = field(init=False, default=None)
905
906 # Make this optional so the type matches the one defined in LoggingMixin
907 _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False)
908 _logger_name: str | None = None
909
910 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
911 __serialized_fields: ClassVar[frozenset[str] | None] = None
912
913 _comps: ClassVar[set[str]] = {
914 "task_id",
915 "dag_id",
916 "owner",
917 "email",
918 "email_on_retry",
919 "retry_delay",
920 "retry_exponential_backoff",
921 "max_retry_delay",
922 "start_date",
923 "end_date",
924 "depends_on_past",
925 "wait_for_downstream",
926 "priority_weight",
927 "execution_timeout",
928 "has_on_execute_callback",
929 "has_on_failure_callback",
930 "has_on_success_callback",
931 "has_on_retry_callback",
932 "has_on_skipped_callback",
933 "do_xcom_push",
934 "multiple_outputs",
935 "allow_nested_operators",
936 "executor",
937 }
938
939 # If True, the Rendered Template fields will be overwritten in DB after execution
940 # This is useful for Taskflow decorators that modify the template fields during execution, like
941 # the @task.bash decorator.
942 overwrite_rtif_after_execution: bool = False
943
944 # If True then the class constructor was called
945 __instantiated: bool = False
946 # List of args as passed to `__init__()`, after apply_defaults() has updated them. Used to "recreate" the task
947 # when mapping
948 # Set via the metaclass
949 __init_kwargs: dict[str, Any] = field(init=False)
950
951 # Set to True before calling execute method
952 _lock_for_execution: bool = False
953
954 # Set to True for an operator instantiated by a mapped operator.
955 __from_mapped: bool = False
956
957 start_trigger_args: StartTriggerArgs | None = None
958 start_from_trigger: bool = False
959
960 # base list which includes all the attrs that don't need deep copy.
961 _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = (
962 "user_defined_macros",
963 "user_defined_filters",
964 "params",
965 )
966
967 # each operator should override this class attr for shallow copy attrs.
968 shallow_copy_attrs: Sequence[str] = ()
969
970 def __setattr__(self: BaseOperator, key: str, value: Any):
971 if converter := getattr(self, f"_convert_{key}", None):
972 value = converter(value)
973 super().__setattr__(key, value)
974 if self.__from_mapped or self._lock_for_execution:
975 return # Skip any custom behavior for validation and during execute.
976 if key in self.__init_kwargs:
977 self.__init_kwargs[key] = value
978 if self.__instantiated and key in self.template_fields:
979 # Resolve upstreams set by assigning an XComArg after initializing
980 # an operator, example:
981 # op = BashOperator()
982 # op.bash_command = "sleep 1"
983 self._set_xcomargs_dependency(key, value)
984
985 def __init__(
986 self,
987 *,
988 task_id: str,
989 owner: str = DEFAULT_OWNER,
990 email: str | Sequence[str] | None = None,
991 email_on_retry: bool = True,
992 email_on_failure: bool = True,
993 retries: int | None = DEFAULT_RETRIES,
994 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
995 retry_exponential_backoff: float = 0,
996 max_retry_delay: timedelta | float | None = None,
997 start_date: datetime | None = None,
998 end_date: datetime | None = None,
999 depends_on_past: bool = False,
1000 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
1001 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
1002 wait_for_downstream: bool = False,
1003 dag: DAG | None = None,
1004 params: collections.abc.MutableMapping[str, Any] | None = None,
1005 default_args: dict | None = None,
1006 priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
1007 weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
1008 queue: str = DEFAULT_QUEUE,
1009 pool: str | None = None,
1010 pool_slots: int = DEFAULT_POOL_SLOTS,
1011 sla: timedelta | None = None,
1012 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
1013 on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1014 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
1015 on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1016 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
1017 on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
1018 pre_execute: TaskPreExecuteHook | None = None,
1019 post_execute: TaskPostExecuteHook | None = None,
1020 trigger_rule: str = DEFAULT_TRIGGER_RULE,
1021 resources: dict[str, Any] | None = None,
1022 run_as_user: str | None = None,
1023 map_index_template: str | None = None,
1024 max_active_tis_per_dag: int | None = None,
1025 max_active_tis_per_dagrun: int | None = None,
1026 executor: str | None = None,
1027 executor_config: dict | None = None,
1028 do_xcom_push: bool = True,
1029 multiple_outputs: bool = False,
1030 inlets: Any | None = None,
1031 outlets: Any | None = None,
1032 task_group: TaskGroup | None = None,
1033 doc: str | None = None,
1034 doc_md: str | None = None,
1035 doc_json: str | None = None,
1036 doc_yaml: str | None = None,
1037 doc_rst: str | None = None,
1038 task_display_name: str | None = None,
1039 logger_name: str | None = None,
1040 allow_nested_operators: bool = True,
1041 **kwargs: Any,
1042 ):
1043 # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any
1044
1045 # Only apply task_group prefix if this operator was not created from a mapped operator
1046 # Mapped operators already have the prefix applied during their creation
1047 if task_group and not self.__from_mapped:
1048 self.task_id = task_group.child_id(task_id)
1049 task_group.add(self)
1050 else:
1051 self.task_id = task_id
1052
1053 super().__init__()
1054 self.task_group = task_group
1055
1056 kwargs.pop("_airflow_mapped_validation_only", None)
1057 if kwargs:
1058 raise TypeError(
1059 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
1060 f"Invalid arguments were:\n**kwargs: {redact(kwargs)}",
1061 )
1062 validate_key(self.task_id)
1063
1064 self.owner = owner
1065 self.email = email
1066 self.email_on_retry = email_on_retry
1067 self.email_on_failure = email_on_failure
1068
1069 if email is not None:
1070 warnings.warn(
1071 "Setting email on a task is deprecated; please migrate to SmtpNotifier.",
1072 RemovedInAirflow4Warning,
1073 stacklevel=2,
1074 )
1075 if email and email_on_retry is not None:
1076 warnings.warn(
1077 "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.",
1078 RemovedInAirflow4Warning,
1079 stacklevel=2,
1080 )
1081 if email and email_on_failure is not None:
1082 warnings.warn(
1083 "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.",
1084 RemovedInAirflow4Warning,
1085 stacklevel=2,
1086 )
1087
1088 if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
1089 raise ValueError(
1090 f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
1091 )
1092 self.execution_timeout = execution_timeout
1093
1094 self.on_execute_callback = _collect_from_input(on_execute_callback)
1095 self.on_failure_callback = _collect_from_input(on_failure_callback)
1096 self.on_success_callback = _collect_from_input(on_success_callback)
1097 self.on_retry_callback = _collect_from_input(on_retry_callback)
1098 self.on_skipped_callback = _collect_from_input(on_skipped_callback)
1099 self._pre_execute_hook = pre_execute
1100 self._post_execute_hook = post_execute
1101
1102 self.start_date = timezone.convert_to_utc(start_date)
1103 self.end_date = timezone.convert_to_utc(end_date)
1104 self.executor = executor
1105 self.executor_config = executor_config or {}
1106 self.run_as_user = run_as_user
1107 # TODO:
1108 # self.retries = parse_retries(retries)
1109 self.retries = retries
1110 self.queue = queue
1111 self.pool = DEFAULT_POOL_NAME if pool is None else pool
1112 self.pool_slots = pool_slots
1113 if self.pool_slots < 1:
1114 dag_str = f" in dag {dag.dag_id}" if dag else ""
1115 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
1116 if sla is not None:
1117 warnings.warn(
1118 "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1",
1119 stacklevel=2,
1120 )
1121
1122 try:
1123 TriggerRule(trigger_rule)
1124 except ValueError:
1125 raise ValueError(
1126 f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]},"
1127 f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
1128 )
1129
1130 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
1131
1132 self.depends_on_past: bool = depends_on_past
1133 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
1134 self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
1135 self.wait_for_downstream: bool = wait_for_downstream
1136 if wait_for_downstream:
1137 self.depends_on_past = True
1138
1139 # Converted by setattr
1140 self.retry_delay = retry_delay # type: ignore[assignment]
1141 self.retry_exponential_backoff = retry_exponential_backoff
1142 if max_retry_delay is not None:
1143 self.max_retry_delay = max_retry_delay
1144
1145 self.resources = resources
1146
1147 self.params = ParamsDict(params)
1148
1149 self.priority_weight = priority_weight
1150 self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
1151
1152 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
1153 self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
1154 self.do_xcom_push: bool = do_xcom_push
1155 self.map_index_template: str | None = map_index_template
1156 self.multiple_outputs: bool = multiple_outputs
1157
1158 self.doc_md = doc_md
1159 self.doc_json = doc_json
1160 self.doc_yaml = doc_yaml
1161 self.doc_rst = doc_rst
1162 self.doc = doc
1163
1164 self._task_display_name = task_display_name
1165
1166 self.allow_nested_operators = allow_nested_operators
1167
1168 self._logger_name = logger_name
1169
1170 # Lineage
1171 self.inlets = _collect_from_input(inlets)
1172 self.outlets = _collect_from_input(outlets)
1173
1174 if isinstance(self.template_fields, str):
1175 warnings.warn(
1176 f"The `template_fields` value for {self.task_type} is a string "
1177 "but should be a list or tuple of string. Wrapping it in a list for execution. "
1178 f"Please update {self.task_type} accordingly.",
1179 UserWarning,
1180 stacklevel=2,
1181 )
1182 self.template_fields = [self.template_fields]
1183
1184 self.is_setup = False
1185 self.is_teardown = False
1186
1187 if SetupTeardownContext.active:
1188 SetupTeardownContext.update_context_map(self)
1189
1190 # We set self.dag right at the end as `_convert__dag` calls `dag.add_task` for us, and we need all the
1191 # other properties to be set at that point
1192 if dag is not None:
1193 self.dag = dag
1194
1195 validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
1196
1197 # Ensure priority_weight is within the valid range
1198 self.priority_weight = db_safe_priority(self.priority_weight)
1199
1200 def __eq__(self, other):
1201 if type(self) is type(other):
1202 # Use getattr() instead of __dict__ as __dict__ doesn't return
1203 # correct values for properties.
1204 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
1205 return False
1206
1207 def __ne__(self, other):
1208 return not self == other
1209
1210 def __hash__(self):
1211 hash_components = [type(self)]
1212 for component in self._comps:
1213 val = getattr(self, component, None)
1214 try:
1215 hash(val)
1216 hash_components.append(val)
1217 except TypeError:
1218 hash_components.append(repr(val))
1219 return hash(tuple(hash_components))
1220
1221 # /Composing Operators ---------------------------------------------
1222
1223 def __gt__(self, other):
1224 """
1225 Return [Operator] > [Outlet].
1226
1227 If other is an attr annotated object it is set as an outlet of this Operator.
1228 """
1229 if not isinstance(other, Iterable):
1230 other = [other]
1231
1232 for obj in other:
1233 if not attrs.has(obj):
1234 raise TypeError(f"Left hand side ({obj}) is not an outlet")
1235 self.add_outlets(other)
1236
1237 return self
1238
1239 def __lt__(self, other):
1240 """
1241 Return [Inlet] > [Operator] or [Operator] < [Inlet].
1242
1243 If other is an attr annotated object it is set as an inlet to this operator.
1244 """
1245 if not isinstance(other, Iterable):
1246 other = [other]
1247
1248 for obj in other:
1249 if not attrs.has(obj):
1250 raise TypeError(f"{obj} cannot be an inlet")
1251 self.add_inlets(other)
1252
1253 return self
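# Illustrative lineage shorthand enabled by __gt__/__lt__ above (assumes Asset is an
# attrs-annotated class; comments only, not executed):
#
#     task > Asset("s3://bucket/out.csv")   # adds the asset to task.outlets
#     Asset("s3://bucket/in.csv") > task    # adds the asset to task.inlets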
1254
1255 def __deepcopy__(self, memo: dict[int, Any]):
1256 # Hack sorting double chained task lists by task_id to avoid hitting
1257 # max_depth on deepcopy operations.
1258 sys.setrecursionlimit(5000) # TODO fix this in a better way
1259
1260 cls = self.__class__
1261 result = cls.__new__(cls)
1262 memo[id(self)] = result
1263
1264 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs
1265
1266 for k, v_org in self.__dict__.items():
1267 if k not in shallow_copy:
1268 v = copy.deepcopy(v_org, memo)
1269 else:
1270 v = copy.copy(v_org)
1271
1272 # Bypass any setters, and set it on the object directly. This works since we are cloning ourselves, so
1273 # we know the type is already fine
1274 result.__dict__[k] = v
1275 return result
1276
1277 def __getstate__(self):
1278 state = dict(self.__dict__)
1279 if "_log" in state:
1280 del state["_log"]
1281
1282 return state
1283
1284 def __setstate__(self, state):
1285 self.__dict__ = state
1286
1287 def add_inlets(self, inlets: Iterable[Any]):
1288 """Set inlets to this operator."""
1289 self.inlets.extend(inlets)
1290
1291 def add_outlets(self, outlets: Iterable[Any]):
1292 """Define the outlets of this operator."""
1293 self.outlets.extend(outlets)
1294
1295 def get_dag(self) -> DAG | None:
1296 return self._dag
1297
1298 @property
1299 def dag(self) -> DAG:
1300 """Returns the Operator's Dag if set, otherwise raises an error."""
1301 if dag := self._dag:
1302 return dag
1303 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
1304
1305 @dag.setter
1306 def dag(self, dag: DAG | None) -> None:
1307 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok."""
1308 self._dag = dag
1309
1310 def _convert__dag(self, dag: DAG | None) -> DAG | None:
1311 # Called automatically by __setattr__ method
1312 from airflow.sdk.definitions.dag import DAG
1313
1314 if dag is None:
1315 return dag
1316
1317 if not isinstance(dag, DAG):
1318 raise TypeError(f"Expected dag; received {dag.__class__.__name__}")
1319 if self._dag is not None and self._dag is not dag:
1320 raise ValueError(f"The dag assigned to {self} can not be changed.")
1321
1322 if self.__from_mapped:
1323 pass # Don't add to dag -- the mapped task takes the place.
1324 elif dag.task_dict.get(self.task_id) is not self:
1325 dag.add_task(self)
1326 return dag
1327
1328 @staticmethod
1329 def _convert_retries(retries: Any) -> int | None:
1330 if retries is None:
1331 return 0
1332 if type(retries) == int: # noqa: E721
1333 return retries
1334 try:
1335 parsed_retries = int(retries)
1336 except (TypeError, ValueError):
1337 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}")
1338 return parsed_retries
1339
1340 @staticmethod
1341 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None:
1342 if value is None or isinstance(value, timedelta):
1343 return value
1344 return timedelta(seconds=value)
1345
1346 _convert_retry_delay = _convert_timedelta
1347 _convert_max_retry_delay = _convert_timedelta
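# Because __setattr__ dispatches to these _convert_* hooks, plain numbers are coerced on
# assignment (illustrative, not executed):
#
#     op.retry_delay = 120        # stored as timedelta(seconds=120)
#     op.max_retry_delay = None   # None passes through unchanged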
1348
1349 @staticmethod
1350 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None:
1351 if resources is None:
1352 return None
1353
1354 from airflow.sdk.definitions.operator_resources import Resources
1355
1356 if isinstance(resources, Resources):
1357 return resources
1358
1359 return Resources(**resources)
1360
1361 def _convert_is_setup(self, value: bool) -> bool:
1362 """
1363 Setter for is_setup property.
1364
1365 :meta private:
1366 """
1367 if self.is_teardown and value:
1368 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.")
1369 return value
1370
1371 def _convert_is_teardown(self, value: bool) -> bool:
1372 if self.is_setup and value:
1373 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.")
1374 return value
1375
1376 @property
1377 def task_display_name(self) -> str:
1378 return self._task_display_name or self.task_id
1379
1380 def has_dag(self):
1381 """Return True if the Operator has been assigned to a Dag."""
1382 return self._dag is not None
1383
1384 def _set_xcomargs_dependencies(self) -> None:
1385 from airflow.sdk.definitions.xcom_arg import XComArg
1386
1387 for f in self.template_fields:
1388 arg = getattr(self, f, NOTSET)
1389 if arg is not NOTSET:
1390 XComArg.apply_upstream_relationship(self, arg)
1391
1392 def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None:
1393 """
1394 Resolve upstream dependencies of a task.
1395
1396 In this way, passing an ``XComArg`` as the value for a template field
1397 will result in creating an upstream relation between the two tasks.
1398
1399 **Example**: ::
1400
1401 with DAG(...):
1402 generate_content = GenerateContentOperator(task_id="generate_content")
1403 send_email = EmailOperator(..., html_content=generate_content.output)
1404
1405 # This is equivalent to
1406 with DAG(...):
1407 generate_content = GenerateContentOperator(task_id="generate_content")
1408 send_email = EmailOperator(
1409 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
1410 )
1411 generate_content >> send_email
1412
1413 """
1414 from airflow.sdk.definitions.xcom_arg import XComArg
1415
1416 if field not in self.template_fields:
1417 return
1418 XComArg.apply_upstream_relationship(self, newvalue)
1419
1420 def on_kill(self) -> None:
1421 """
1422 Override this method to clean up subprocesses when a task instance gets killed.
1423
1424 Any use of the threading, subprocess or multiprocessing module within an
1425 operator needs to be cleaned up, or it will leave ghost processes behind.
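
        A minimal sketch, assuming a hypothetical ``self._proc`` handle created by
        ``execute`` via the ``subprocess`` module::

            def on_kill(self) -> None:
                if getattr(self, "_proc", None) is not None:
                    self._proc.terminate()
                    self._proc.wait()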
1426 """
1427
1428 def __repr__(self):
1429 return f"<Task({self.task_type}): {self.task_id}>"
1430
1431 @property
1432 def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
1433 return self.__class__
1434
1435 @property
1436 def task_type(self) -> str:
1437 """@property: type of the task."""
1438 return self.__class__.__name__
1439
1440 @property
1441 def operator_name(self) -> str:
1442 """@property: use a more friendly display name for the operator, if set."""
1443 try:
1444 return self.custom_operator_name # type: ignore
1445 except AttributeError:
1446 return self.task_type
1447
1448 @property
1449 def roots(self) -> list[BaseOperator]:
1450 """Required by DAGNode."""
1451 return [self]
1452
1453 @property
1454 def leaves(self) -> list[BaseOperator]:
1455 """Required by DAGNode."""
1456 return [self]
1457
1458 @property
1459 def output(self) -> XComArg:
        """Return a reference to the XCom pushed by the current operator."""
1461 from airflow.sdk.definitions.xcom_arg import XComArg
1462
1463 return XComArg(operator=self)
1464
1465 @classmethod
1466 def get_serialized_fields(cls):
1467 """Stringified Dags and operators contain exactly these fields."""
1468 if not cls.__serialized_fields:
1469 from airflow.sdk.definitions._internal.contextmanager import DagContext
1470
            # Make sure the "fake" task below is not added to the currently active
            # dag in context; otherwise it would raise
            # `RuntimeError: dictionary changed size during iteration`
            # in the SerializedDAG.serialize_dag() call.
1475 DagContext.push(None)
1476 cls.__serialized_fields = frozenset(
1477 vars(BaseOperator(task_id="test")).keys()
1478 - {
1479 "upstream_task_ids",
1480 "default_args",
1481 "dag",
1482 "_dag",
1483 "label",
1484 "_BaseOperator__instantiated",
1485 "_BaseOperator__init_kwargs",
1486 "_BaseOperator__from_mapped",
1487 "on_failure_fail_dagrun",
1488 "task_group",
1489 "_task_type",
1490 "operator_extra_links",
1491 "on_execute_callback",
1492 "on_failure_callback",
1493 "on_success_callback",
1494 "on_retry_callback",
1495 "on_skipped_callback",
1496 }
                | {  # Class-level defaults and `@property` attributes need to be added to this list
1498 "start_date",
1499 "end_date",
1500 "task_type",
1501 "ui_color",
1502 "ui_fgcolor",
1503 "template_ext",
1504 "template_fields",
1505 "template_fields_renderers",
1506 "params",
1507 "is_setup",
1508 "is_teardown",
1509 "on_failure_fail_dagrun",
1510 "map_index_template",
1511 "start_trigger_args",
1512 "_needs_expansion",
1513 "start_from_trigger",
1514 "max_retry_delay",
1515 "has_on_execute_callback",
1516 "has_on_failure_callback",
1517 "has_on_success_callback",
1518 "has_on_retry_callback",
1519 "has_on_skipped_callback",
1520 }
1521 )
1522 DagContext.pop()
1523
1524 return cls.__serialized_fields
1525
1526 def prepare_for_execution(self) -> Self:
1527 """Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
1528 other = copy.copy(self)
1529 other._lock_for_execution = True
1530 return other
1531
1532 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
1533 """Serialize; required by DAGNode."""
1534 from airflow.serialization.enums import DagAttributeTypes
1535
1536 return DagAttributeTypes.OP, self.task_id
1537
1538 def unmap(self, resolve: None | Mapping[str, Any]) -> Self:
1539 """
1540 Get the "normal" operator from the current operator.
1541
1542 Since a BaseOperator is not mapped to begin with, this simply returns
1543 the original operator.
1544
1545 :meta private:
1546 """
1547 return self
1548
1549 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
1550 """
1551 Get the start_trigger_args value of the current abstract operator.
1552
1553 Since a BaseOperator is not mapped to begin with, this simply returns
1554 the original value of start_trigger_args.
1555
1556 :meta private:
1557 """
1558 return self.start_trigger_args
1559
1560 def render_template_fields(
1561 self,
1562 context: Context,
1563 jinja_env: jinja2.Environment | None = None,
1564 ) -> None:
1565 """
1566 Template all attributes listed in *self.template_fields*.
1567
1568 This mutates the attributes in-place and is irreversible.
1569
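        For example, an operator declaring ``template_fields = ("bash_command",)`` and created
        with ``bash_command="echo {{ ds }}"`` holds the rendered string on that attribute after
        this call (the field name and values here are illustrative)::

            op.render_template_fields(context)
            # op.bash_command is now e.g. "echo 2024-01-01"
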
1570 :param context: Context dict with values to apply on content.
1571 :param jinja_env: Jinja's environment to use for rendering.
1572 """
1573 if not jinja_env:
1574 jinja_env = self.get_template_env()
1575 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
1576
1577 def pre_execute(self, context: Any):
1578 """Execute right before self.execute() is called."""
1579
1580 def execute(self, context: Context) -> Any:
1581 """
        Override this method when deriving an operator.

        This is the main method that executes the task. ``context`` is the same
        dictionary used when rendering Jinja templates.

        Refer to get_template_context for more context.
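
        A minimal sketch of a custom operator (the class and parameter names here are
        illustrative, not part of this module)::

            class HelloOperator(BaseOperator):
                def __init__(self, name: str, **kwargs):
                    super().__init__(**kwargs)
                    self.name = name

                def execute(self, context):
                    message = f"Hello {self.name}!"
                    self.log.info(message)
                    return message  # returned values are pushed to XCom by default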
1588 """
1589 raise NotImplementedError()
1590
1591 def post_execute(self, context: Any, result: Any = None):
1592 """
1593 Execute right after self.execute() is called.
1594
1595 It is passed the execution context and any results returned by the operator.
1596 """
1597
1598 def defer(
1599 self,
1600 *,
1601 trigger: BaseTrigger,
1602 method_name: str,
1603 kwargs: dict[str, Any] | None = None,
1604 timeout: timedelta | int | float | None = None,
1605 ) -> NoReturn:
1606 """
1607 Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
1608
        This is achieved by raising a special exception (``TaskDeferred``) that is caught by
        the main ``_execute_task`` wrapper. A trigger can either send execution back to the task
        or end the task instance directly. If the trigger will end the task instance itself,
        ``method_name`` should be None; otherwise, provide the name of the method to call when
        execution resumes in the task.
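
        A typical pattern in a deferrable operator looks like the sketch below; the trigger
        class and the resume method name are illustrative::

            def execute(self, context):
                self.defer(
                    trigger=SomeTimerTrigger(moment=timezone.utcnow() + timedelta(hours=1)),
                    method_name="execute_complete",
                )

            def execute_complete(self, context, event=None):
                # invoked via resume_execution() once the trigger fires an event
                return event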
1614 """
1615 from airflow.sdk.exceptions import TaskDeferred
1616
1617 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
1618
1619 def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
1620 """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed."""
1621 from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout
1622
1623 if next_kwargs is None:
1624 next_kwargs = {}
        # TRIGGER_FAIL_REPR (the "__fail__" sentinel) is a special signal value for
        # next_method that indicates this task was scheduled specifically to fail.
        if next_method == TRIGGER_FAIL_REPR:
1630 traceback = next_kwargs.get("traceback")
1631 if traceback is not None:
1632 self.log.error("Trigger failed:\n%s", "\n".join(traceback))
1633 if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT:
1634 raise TaskDeferralTimeout(error)
1635 raise TaskDeferralError(error)
1636 # Grab the callable off the Operator/Task and add in any kwargs
1637 execute_callable = getattr(self, next_method)
1638 return execute_callable(context, **next_kwargs)
1639
1640 def dry_run(self) -> None:
1641 """Perform dry run for the operator - just render template fields."""
1642 self.log.info("Dry run")
1643 for f in self.template_fields:
1644 try:
1645 content = getattr(self, f)
1646 except AttributeError:
1647 raise AttributeError(
1648 f"{f!r} is configured as a template field "
1649 f"but {self.task_type} does not have this attribute."
1650 )
1651
1652 if content and isinstance(content, str):
1653 self.log.info("Rendering template for %s", f)
1654 self.log.info(content)
1655
1656 @property
1657 def has_on_execute_callback(self) -> bool:
1658 """Return True if the task has execute callbacks."""
1659 return bool(self.on_execute_callback)
1660
1661 @property
1662 def has_on_failure_callback(self) -> bool:
1663 """Return True if the task has failure callbacks."""
1664 return bool(self.on_failure_callback)
1665
1666 @property
1667 def has_on_success_callback(self) -> bool:
1668 """Return True if the task has success callbacks."""
1669 return bool(self.on_success_callback)
1670
1671 @property
1672 def has_on_retry_callback(self) -> bool:
1673 """Return True if the task has retry callbacks."""
1674 return bool(self.on_retry_callback)
1675
1676 @property
1677 def has_on_skipped_callback(self) -> bool:
1678 """Return True if the task has skipped callbacks."""
1679 return bool(self.on_skipped_callback)
1680
1681
1682def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
1683 r"""
1684 Given a number of tasks, builds a dependency chain.
1685
    This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArgs,
    TaskGroups, or lists containing any mix of these types. If you want to chain between two lists,
    they must be of the same length.
1689
1690 Using classic operators/sensors:
1691
1692 .. code-block:: python
1693
1694 chain(t1, [t2, t3], [t4, t5], t6)
1695
1696 is equivalent to::
1697
1698 / -> t2 -> t4 \
1699 t1 -> t6
1700 \ -> t3 -> t5 /
1701
1702 .. code-block:: python
1703
1704 t1.set_downstream(t2)
1705 t1.set_downstream(t3)
1706 t2.set_downstream(t4)
1707 t3.set_downstream(t5)
1708 t4.set_downstream(t6)
1709 t5.set_downstream(t6)
1710
1711 Using task-decorated functions aka XComArgs:
1712
1713 .. code-block:: python
1714
1715 chain(x1(), [x2(), x3()], [x4(), x5()], x6())
1716
1717 is equivalent to::
1718
1719 / -> x2 -> x4 \
1720 x1 -> x6
1721 \ -> x3 -> x5 /
1722
1723 .. code-block:: python
1724
1725 x1 = x1()
1726 x2 = x2()
1727 x3 = x3()
1728 x4 = x4()
1729 x5 = x5()
1730 x6 = x6()
1731 x1.set_downstream(x2)
1732 x1.set_downstream(x3)
1733 x2.set_downstream(x4)
1734 x3.set_downstream(x5)
1735 x4.set_downstream(x6)
1736 x5.set_downstream(x6)
1737
1738 Using TaskGroups:
1739
1740 .. code-block:: python
1741
1742 chain(t1, task_group1, task_group2, t2)
1743
        # is equivalent to
        t1.set_downstream(task_group1)
1745 task_group1.set_downstream(task_group2)
1746 task_group2.set_downstream(t2)
1747
1748
    It is also possible to mix classic operators/sensors, EdgeModifiers, XComArgs, and TaskGroups:
1750
1751 .. code-block:: python
1752
1753 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
1754
1755 is equivalent to::
1756
1757 / "branch one" -> x1 \
1758 t1 -> task_group1 -> x3
1759 \ "branch two" -> x2 /
1760
1761 .. code-block:: python
1762
1763 x1 = x1()
1764 x2 = x2()
1765 x3 = x3()
1766 label1 = Label("branch one")
1767 label2 = Label("branch two")
1768 t1.set_downstream(label1)
1769 label1.set_downstream(x1)
        t1.set_downstream(label2)
1771 label2.set_downstream(x2)
1772 x1.set_downstream(task_group1)
1773 x2.set_downstream(task_group1)
1774 task_group1.set_downstream(x3)
1775
1776 # or
1777
1778 x1 = x1()
1779 x2 = x2()
1780 x3 = x3()
1781 t1.set_downstream(x1, edge_modifier=Label("branch one"))
1782 t1.set_downstream(x2, edge_modifier=Label("branch two"))
1783 x1.set_downstream(task_group1)
1784 x2.set_downstream(task_group1)
1785 task_group1.set_downstream(x3)
1786
1787
1788 :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
1789 """
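    # Walk the arguments pairwise: single tasks/groups are wired up directly, while two
    # adjacent sequences must be the same length and are connected element-wise.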
1790 for up_task, down_task in zip(tasks, tasks[1:]):
1791 if isinstance(up_task, DependencyMixin):
1792 up_task.set_downstream(down_task)
1793 continue
1794 if isinstance(down_task, DependencyMixin):
1795 down_task.set_upstream(up_task)
1796 continue
1797 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
1798 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
1799 up_task_list = up_task
1800 down_task_list = down_task
1801 if len(up_task_list) != len(down_task_list):
1802 raise ValueError(
1803 f"Chain not supported for different length Iterable. "
1804 f"Got {len(up_task_list)} and {len(down_task_list)}."
1805 )
1806 for up_t, down_t in zip(up_task_list, down_task_list):
1807 up_t.set_downstream(down_t)
1808
1809
1810def cross_downstream(
1811 from_tasks: Sequence[DependencyMixin],
1812 to_tasks: DependencyMixin | Sequence[DependencyMixin],
1813):
1814 r"""
1815 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
1816
1817 Using classic operators/sensors:
1818
1819 .. code-block:: python
1820
1821 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
1822
1823 is equivalent to::
1824
1825 t1 ---> t4
1826 \ /
1827 t2 -X -> t5
1828 / \
1829 t3 ---> t6
1830
1831 .. code-block:: python
1832
1833 t1.set_downstream(t4)
1834 t1.set_downstream(t5)
1835 t1.set_downstream(t6)
1836 t2.set_downstream(t4)
1837 t2.set_downstream(t5)
1838 t2.set_downstream(t6)
1839 t3.set_downstream(t4)
1840 t3.set_downstream(t5)
1841 t3.set_downstream(t6)
1842
1843 Using task-decorated functions aka XComArgs:
1844
1845 .. code-block:: python
1846
1847 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
1848
1849 is equivalent to::
1850
1851 x1 ---> x4
1852 \ /
1853 x2 -X -> x5
1854 / \
1855 x3 ---> x6
1856
1857 .. code-block:: python
1858
1859 x1 = x1()
1860 x2 = x2()
1861 x3 = x3()
1862 x4 = x4()
1863 x5 = x5()
1864 x6 = x6()
1865 x1.set_downstream(x4)
1866 x1.set_downstream(x5)
1867 x1.set_downstream(x6)
1868 x2.set_downstream(x4)
1869 x2.set_downstream(x5)
1870 x2.set_downstream(x6)
1871 x3.set_downstream(x4)
1872 x3.set_downstream(x5)
1873 x3.set_downstream(x6)
1874
    It is also possible to mix classic operators/sensors and XComArg tasks:
1876
1877 .. code-block:: python
1878
1879 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1880
1881 is equivalent to::
1882
1883 t1 ---> x1
1884 \ /
1885 x2 -X -> t2
1886 / \
1887 t3 ---> x3
1888
1889 .. code-block:: python
1890
1891 x1 = x1()
1892 x2 = x2()
1893 x3 = x3()
1894 t1.set_downstream(x1)
1895 t1.set_downstream(t2)
1896 t1.set_downstream(x3)
1897 x2.set_downstream(x1)
1898 x2.set_downstream(t2)
1899 x2.set_downstream(x3)
1900 t3.set_downstream(x1)
1901 t3.set_downstream(t2)
1902 t3.set_downstream(x3)
1903
1904 :param from_tasks: List of tasks or XComArgs to start from.
1905 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1906 """
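    # set_downstream accepts a single task or a sequence, so every task in from_tasks
    # gets an edge to each of to_tasks.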
1907 for task in from_tasks:
1908 task.set_downstream(to_tasks)
1909
1910
1911def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
1912 """
1913 Simplify task dependency definition.
1914
1915 E.g.: suppose you want precedence like so::
1916
            ╭─op2─╮ ╭─op4─╮
        op1─┤     ├─┤─op5─┤─op7
            ╰─op3─╯ ╰─op6─╯
1920
    Then you can accomplish it like so::
1922
1923 chain_linear(op1, [op2, op3], [op4, op5, op6], op7)
1924
1925 :param elements: a list of operators / lists of operators
1926 """
1927 if not elements:
1928 raise ValueError("No tasks provided; nothing to do.")
1929 prev_elem = None
1930 deps_set = False
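    # Walk the elements left to right; every task in the previous group gets a downstream
    # edge to the current element (which may itself be a single task or a whole group).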
1931 for curr_elem in elements:
1932 if isinstance(curr_elem, EdgeModifier):
1933 raise ValueError("Labels are not supported by chain_linear")
1934 if prev_elem is not None:
1935 for task in prev_elem:
1936 task >> curr_elem
1937 if not deps_set:
1938 deps_set = True
1939 prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
1940 if not deps_set:
1941 raise ValueError("No dependencies were set. Did you forget to expand with `*`?")