1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import abc
21import asyncio
22import collections.abc
23import contextlib
24import copy
25import inspect
26import sys
27import warnings
28from asyncio import AbstractEventLoop
29from collections.abc import Callable, Collection, Generator, Iterable, Mapping, Sequence
30from contextvars import ContextVar
31from dataclasses import dataclass, field
32from datetime import datetime, timedelta
33from enum import Enum
34from functools import total_ordering, wraps
35from types import FunctionType
36from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, TypeVar, cast
37
38import attrs
39
40from airflow.sdk import TriggerRule, timezone
41from airflow.sdk._shared.secrets_masker import redact
42from airflow.sdk.definitions._internal.abstractoperator import (
43 DEFAULT_EMAIL_ON_FAILURE,
44 DEFAULT_EMAIL_ON_RETRY,
45 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
46 DEFAULT_OWNER,
47 DEFAULT_POOL_NAME,
48 DEFAULT_POOL_SLOTS,
49 DEFAULT_PRIORITY_WEIGHT,
50 DEFAULT_QUEUE,
51 DEFAULT_RETRIES,
52 DEFAULT_RETRY_DELAY,
53 DEFAULT_TASK_EXECUTION_TIMEOUT,
54 DEFAULT_TRIGGER_RULE,
55 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
56 DEFAULT_WEIGHT_RULE,
57 AbstractOperator,
58 DependencyMixin,
59 TaskStateChangeCallback,
60)
61from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack
62from airflow.sdk.definitions._internal.node import validate_key
63from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
64from airflow.sdk.definitions._internal.types import NOTSET, validate_instance_args
65from airflow.sdk.definitions.edges import EdgeModifier
66from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
67from airflow.sdk.definitions.param import ParamsDict
68from airflow.sdk.exceptions import RemovedInAirflow4Warning
69
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
# mysql: -2147483648 to +2147483647 (see https://dev.mysql.com/doc/refman/8.4/en/integer-types.html)
# sqlite: -9223372036854775808 to +9223372036854775807 (see https://sqlite.org/datatype3.html)
DB_SAFE_MINIMUM = -2147483648
DB_SAFE_MAXIMUM = 2147483647


def db_safe_priority(priority_weight: int) -> int:
    """Clamp *priority_weight* to the signed 32-bit range every supported database can store."""
    if priority_weight < DB_SAFE_MINIMUM:
        return DB_SAFE_MINIMUM
    if priority_weight > DB_SAFE_MAXIMUM:
        return DB_SAFE_MAXIMUM
    return priority_weight
81
82
83C = TypeVar("C", bound=Callable)
84T = TypeVar("T", bound=FunctionType)
85
86if TYPE_CHECKING:
87 from types import ClassMethodDescriptorType
88
89 import jinja2
90 from typing_extensions import Self
91
92 from airflow.sdk.bases.operatorlink import BaseOperatorLink
93 from airflow.sdk.definitions.context import Context
94 from airflow.sdk.definitions.dag import DAG
95 from airflow.sdk.definitions.operator_resources import Resources
96 from airflow.sdk.definitions.taskgroup import TaskGroup
97 from airflow.sdk.definitions.xcom_arg import XComArg
98 from airflow.serialization.enums import DagAttributeTypes
99 from airflow.task.priority_strategy import PriorityWeightStrategy
100 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
101
102 TaskPreExecuteHook = Callable[[Context], None]
103 TaskPostExecuteHook = Callable[[Context, Any], None]
104
105__all__ = [
106 "BaseAsyncOperator",
107 "BaseOperator",
108 "chain",
109 "chain_linear",
110 "cross_downstream",
111]
112
113
class TriggerFailureReason(str, Enum):
    """
    Enumerate why a deferred task's trigger ended unsuccessfully.

    Members compare equal to their string values (``str`` mixin).

    Internal use only.

    :meta private:
    """

    TRIGGER_TIMEOUT = "Trigger timeout"
    TRIGGER_FAILURE = "Trigger failure"
125
126
# Sentinel string rendered in place of a trigger result when the trigger fails.
TRIGGER_FAIL_REPR = "__fail__"
"""String value to represent trigger failure.

Internal use only.

:meta private:
"""
134
135
def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
    """Copy the default args and params a task inherits from its dag and enclosing task group."""
    if not dag:
        return {}, ParamsDict()
    parent_args = copy.copy(dag.default_args)
    parent_params = copy.deepcopy(dag.params)
    parent_params._fill_missing_param_source("dag")
    if task_group:
        group_args = task_group.default_args
        if group_args and not isinstance(group_args, collections.abc.Mapping):
            raise TypeError("default_args must be a mapping")
        # Task-group values take precedence over dag-level ones.
        parent_args.update(group_args)
    return parent_args, parent_params
147
148
def get_merged_defaults(
    dag: DAG | None,
    task_group: TaskGroup | None,
    task_params: collections.abc.MutableMapping | None,
    task_default_args: dict | None,
) -> tuple[dict, ParamsDict]:
    """
    Merge dag-, task-group-, and task-level default args and params.

    Precedence from lowest to highest: dag, task group, task.

    :param dag: Dag the task belongs to, if any.
    :param task_group: Task group the task belongs to, if any.
    :param task_params: Params passed directly to the task.
    :param task_default_args: ``default_args`` passed directly to the task;
        a ``"params"`` key inside it is merged into the returned params too.
    :raises TypeError: If *task_params* or *task_default_args* is not a mapping.
    :return: A ``(default_args, params)`` tuple.
    """
    args, params = _get_parent_defaults(dag, task_group)
    if task_params:
        if not isinstance(task_params, collections.abc.Mapping):
            raise TypeError(f"params must be a mapping, got {type(task_params)}")

        task_params = ParamsDict(task_params)
        task_params._fill_missing_param_source("task")
        params.update(task_params)

    if task_default_args:
        if not isinstance(task_default_args, collections.abc.Mapping):
            # Bug fix: this previously reported type(task_params), which is
            # misleading when default_args is the offending argument.
            raise TypeError(f"default_args must be a mapping, got {type(task_default_args)}")
        args.update(task_default_args)
        with contextlib.suppress(KeyError):
            if params_from_default_args := ParamsDict(task_default_args["params"] or {}):
                params_from_default_args._fill_missing_param_source("task")
                params.update(params_from_default_args)

    return args, params
174
175
176def parse_retries(retries: Any) -> int | None:
177 if retries is None:
178 return 0
179 if type(retries) == int: # noqa: E721
180 return retries
181 try:
182 parsed_retries = int(retries)
183 except (TypeError, ValueError):
184 raise RuntimeError(f"'retries' type must be int, not {type(retries).__name__}")
185 return parsed_retries
186
187
188def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta:
189 if isinstance(value, timedelta):
190 return value
191 return timedelta(seconds=value)
192
193
def coerce_resources(resources: "dict[str, Any] | None") -> "Resources | None":
    """Build a ``Resources`` object from a mapping of constructor kwargs, passing ``None`` through."""
    if resources is None:
        return None
    # Imported lazily so merely importing this module stays cheap.
    from airflow.sdk.definitions.operator_resources import Resources

    return Resources(**resources)
200
201
@contextlib.contextmanager
def event_loop() -> Generator[AbstractEventLoop]:
    """
    Yield a usable asyncio event loop, creating one when necessary.

    A pre-existing open loop is yielded untouched. When no loop exists or the
    current one is already closed, a fresh loop is created and installed, then
    closed and unset again once the context exits.
    """
    created_here = False
    loop: AbstractEventLoop | None = None
    try:
        try:
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                # Treat an already-closed loop the same as having no loop.
                raise RuntimeError
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            created_here = True
        yield loop
    finally:
        # Only tear down a loop we created ourselves.
        if created_here and loop is not None:
            with contextlib.suppress(AttributeError):
                loop.close()
                asyncio.set_event_loop(None)
221
222
class _PartialDescriptor:
    """Descriptor ensuring ``.partial`` is only reachable from Operator classes, never instances."""

    # Bound by BaseOperatorMeta to the real ``partial`` classmethod.
    class_method: "ClassMethodDescriptorType | None" = None

    def __get__(
        self, obj: "BaseOperator", cls: "type[BaseOperator] | None" = None
    ) -> "Callable[..., OperatorPartial]":
        if obj is None:
            # Class-level access: delegate to the real classmethod.
            return self.class_method.__get__(cls, cls)

        # Named "partial" so it looks nicer in stack traces.
        def partial(**kwargs):
            raise TypeError("partial can only be called on Operator classes, not Tasks themselves")

        return partial
238
239
# Fallback values injected into ``partial()`` kwargs when neither the caller
# nor dag-level ``default_args`` supplied one. Keys must match
# ``BaseOperator.__init__`` parameter names.
OPERATOR_DEFAULTS: dict[str, Any] = {
    "allow_nested_operators": True,
    "depends_on_past": False,
    "email_on_failure": DEFAULT_EMAIL_ON_FAILURE,
    "email_on_retry": DEFAULT_EMAIL_ON_RETRY,
    "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
    # "executor": DEFAULT_EXECUTOR,
    "executor_config": {},
    "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
    "inlets": [],
    "map_index_template": None,
    "on_execute_callback": [],
    "on_failure_callback": [],
    "on_retry_callback": [],
    "on_skipped_callback": [],
    "on_success_callback": [],
    "outlets": [],
    "owner": DEFAULT_OWNER,
    "pool_slots": DEFAULT_POOL_SLOTS,
    "priority_weight": DEFAULT_PRIORITY_WEIGHT,
    "queue": DEFAULT_QUEUE,
    "retries": DEFAULT_RETRIES,
    "retry_delay": DEFAULT_RETRY_DELAY,
    "retry_exponential_backoff": 0,
    "trigger_rule": DEFAULT_TRIGGER_RULE,
    "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
    "wait_for_downstream": False,
    "weight_rule": DEFAULT_WEIGHT_RULE,
}
269
270
271# This is what handles the actual mapping.
272
273if TYPE_CHECKING:
274
    # Typing-only signature of ``partial``: it enumerates the accepted keyword
    # arguments for IDEs and static checkers. The runtime implementation (in
    # the ``else`` branch below) accepts **kwargs and validates dynamically.
    def partial(
        operator_class: type[BaseOperator],
        *,
        task_id: str,
        dag: DAG | None = None,
        task_group: TaskGroup | None = None,
        start_date: datetime = ...,
        end_date: datetime = ...,
        owner: str = ...,
        email: None | str | Iterable[str] = ...,
        params: collections.abc.MutableMapping | None = None,
        resources: dict[str, Any] | None = ...,
        trigger_rule: str = ...,
        depends_on_past: bool = ...,
        ignore_first_depends_on_past: bool = ...,
        wait_for_past_depends_before_skipping: bool = ...,
        wait_for_downstream: bool = ...,
        retries: int | None = ...,
        queue: str = ...,
        pool: str = ...,
        pool_slots: int = ...,
        execution_timeout: timedelta | None = ...,
        max_retry_delay: None | timedelta | float = ...,
        retry_delay: timedelta | float = ...,
        retry_exponential_backoff: float = ...,
        priority_weight: int = ...,
        weight_rule: str | PriorityWeightStrategy = ...,
        sla: timedelta | None = ...,
        map_index_template: str | None = ...,
        max_active_tis_per_dag: int | None = ...,
        max_active_tis_per_dagrun: int | None = ...,
        on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = ...,
        run_as_user: str | None = ...,
        executor: str | None = ...,
        executor_config: dict | None = ...,
        inlets: Any | None = ...,
        outlets: Any | None = ...,
        doc: str | None = ...,
        doc_md: str | None = ...,
        doc_json: str | None = ...,
        doc_yaml: str | None = ...,
        doc_rst: str | None = ...,
        task_display_name: str | None = ...,
        logger_name: str | None = ...,
        allow_nested_operators: bool = True,
        **kwargs,
    ) -> OperatorPartial: ...
326else:
327
328 def partial(
329 operator_class: type[BaseOperator],
330 *,
331 task_id: str,
332 dag: DAG | None = None,
333 task_group: TaskGroup | None = None,
334 params: collections.abc.MutableMapping | None = None,
335 **kwargs,
336 ):
337 from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
338
339 validate_mapping_kwargs(operator_class, "partial", kwargs)
340
341 dag = dag or DagContext.get_current()
342 if dag:
343 task_group = task_group or TaskGroupContext.get_current(dag)
344 if task_group:
345 task_id = task_group.child_id(task_id)
346
347 # Merge Dag and task group level defaults into user-supplied values.
348 dag_default_args, partial_params = get_merged_defaults(
349 dag=dag,
350 task_group=task_group,
351 task_params=params,
352 task_default_args=kwargs.pop("default_args", None),
353 )
354
355 # Create partial_kwargs from args and kwargs
356 partial_kwargs: dict[str, Any] = {
357 "task_id": task_id,
358 "dag": dag,
359 "task_group": task_group,
360 **kwargs,
361 }
362
363 # Inject Dag-level default args into args provided to this function.
364 # Most of the default args will be retrieved during unmapping; here we
365 # only ensure base properties are correctly set for the scheduler.
366 partial_kwargs.update(
367 (k, v)
368 for k, v in dag_default_args.items()
369 if k not in partial_kwargs and k in BaseOperator.__init__._BaseOperatorMeta__param_names
370 )
371
372 # Fill fields not provided by the user with default values.
373 partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs)
374
375 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
376 if "task_concurrency" in kwargs: # Reject deprecated option.
377 raise TypeError("unexpected argument: task_concurrency")
378 if start_date := partial_kwargs.get("start_date", None):
379 partial_kwargs["start_date"] = timezone.convert_to_utc(start_date)
380 if end_date := partial_kwargs.get("end_date", None):
381 partial_kwargs["end_date"] = timezone.convert_to_utc(end_date)
382 if partial_kwargs["pool_slots"] < 1:
383 dag_str = ""
384 if dag:
385 dag_str = f" in dag {dag.dag_id}"
386 raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")
387 if retries := partial_kwargs.get("retries"):
388 partial_kwargs["retries"] = BaseOperator._convert_retries(retries)
389 partial_kwargs["retry_delay"] = BaseOperator._convert_retry_delay(partial_kwargs["retry_delay"])
390 partial_kwargs["max_retry_delay"] = BaseOperator._convert_max_retry_delay(
391 partial_kwargs.get("max_retry_delay", None)
392 )
393
394 for k in ("execute", "failure", "success", "retry", "skipped"):
395 partial_kwargs[attr] = _collect_from_input(partial_kwargs.get(attr := f"on_{k}_callback"))
396
397 return OperatorPartial(
398 operator_class=operator_class,
399 kwargs=partial_kwargs,
400 params=partial_params,
401 )
402
403
class ExecutorSafeguard:
    """
    The ExecutorSafeguard decorator.

    Checks if the execute method of an operator isn't manually called outside
    the TaskInstance as we want to avoid bad mixing between decorated and
    classic operators.
    """

    # When True (set below from "core/unit_test_mode" config), the guard is disabled.
    test_mode: ClassVar[bool] = False
    # Tracks which operator instance is legitimately executing in the current
    # context, so nested calls (e.g. via super()) are not flagged again.
    tracker: ClassVar[ContextVar[BaseOperator]] = ContextVar("ExecutorSafeguard_sentinel")
    # Opaque marker passed via the "<ClassName>__sentinel" kwarg to prove the
    # call came from the Task Runner.
    sentinel_value: ClassVar[object] = object()

    @classmethod
    def decorator(cls, func):
        """Wrap an ``execute`` method so calls made outside the Task Runner warn or raise."""

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            sentinel_key = f"{self.__class__.__name__}__sentinel"
            sentinel = kwargs.pop(sentinel_key, None)

            with contextlib.ExitStack() as stack:
                if sentinel is cls.sentinel_value:
                    # Legitimate entry: remember this instance for the duration
                    # of the call so nested execute() calls pass the check.
                    token = cls.tracker.set(self)
                    sentinel = self
                    stack.callback(cls.tracker.reset, token)
                else:
                    # No sentinel passed in, maybe the subclass execute did have it passed?
                    sentinel = cls.tracker.get(None)

                if not cls.test_mode and sentinel is not self:
                    message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside of the Task Runner!"
                    if not self.allow_nested_operators:
                        raise RuntimeError(message)
                    self.log.warning(message)

                    # Now that we've logged, set sentinel so that `super()` calls don't log again
                    token = cls.tracker.set(self)
                    stack.callback(cls.tracker.reset, token)

                return func(self, *args, **kwargs)

        return wrapper
446
447
# Only consult Airflow configuration when the host process has already loaded
# it; importing it here would trigger config loading as a side effect.
if "airflow.configuration" in sys.modules:
    # Don't try and import it if it's not already loaded
    from airflow.sdk.configuration import conf

    ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode")
453
454
def _collect_from_input(value_or_values: "None | C | Collection[C]") -> "list[C]":
    """
    Normalize ``None``, a single value, or a collection of values into a list.

    Any falsy input (``None``, empty collection) normalizes to ``[]``.
    """
    if not value_or_values:
        return []
    if not isinstance(value_or_values, Collection):
        return [value_or_values]
    return list(value_or_values)
461
462
class BaseOperatorMeta(abc.ABCMeta):
    """Metaclass of BaseOperator."""

    @classmethod
    def _apply_defaults(cls, func: T) -> T:
        """
        Look for an argument named "default_args", and fill the unspecified arguments from it.

        Since python2.* isn't clear about which arguments are missing when
        calling a function, and that this can be quite confusing with multi-level
        inheritance and argument defaults, this decorator also alerts with
        specific information about the missing arguments.
        """
        # Cache inspect.signature for the wrapper closure to avoid calling it
        # at every decorated invocation. This is separate sig_cache created
        # per decoration, i.e. each function decorated using apply_defaults will
        # have a different sig_cache.
        sig_cache = inspect.signature(func)
        # Every named parameter except "self" and *args/**kwargs catch-alls.
        non_variadic_params = {
            name: param
            for (name, param) in sig_cache.parameters.items()
            if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        }
        # Parameters with no default are required; "task_id" is exempt from the check.
        non_optional_args = {
            name
            for name, param in non_variadic_params.items()
            if param.default == param.empty and name != "task_id"
        }

        fixup_decorator_warning_stack(func)

        @wraps(func)
        def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
            from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext

            # Operators must be constructed with keyword arguments only.
            if args:
                raise TypeError("Use keyword arguments when initializing operators")

            # Flag set when this instance is being created from a MappedOperator;
            # used below to skip XComArg dependency wiring.
            instantiated_from_mapped = kwargs.pop(
                "_airflow_from_mapped",
                getattr(self, "_BaseOperator__from_mapped", False),
            )

            # Resolve dag/task group from the active context managers when not
            # passed explicitly.
            dag: DAG | None = kwargs.get("dag")
            if dag is None:
                dag = DagContext.get_current()
            if dag is not None:
                kwargs["dag"] = dag

            task_group: TaskGroup | None = kwargs.get("task_group")
            if dag and not task_group:
                task_group = TaskGroupContext.get_current(dag)
            if task_group is not None:
                kwargs["task_group"] = task_group

            # Merge dag/task-group-level default_args and params with the ones
            # given directly to this constructor.
            default_args, merged_params = get_merged_defaults(
                dag=dag,
                task_group=task_group,
                task_params=kwargs.pop("params", None),
                task_default_args=kwargs.pop("default_args", None),
            )

            # Fill declared parameters the caller omitted from default_args.
            for arg in sig_cache.parameters:
                if arg not in kwargs and arg in default_args:
                    kwargs[arg] = default_args[arg]

            # Report required arguments that are still missing after the merge.
            missing_args = non_optional_args.difference(kwargs)
            if len(missing_args) == 1:
                raise TypeError(f"missing keyword argument {missing_args.pop()!r}")
            if missing_args:
                display = ", ".join(repr(a) for a in sorted(missing_args))
                raise TypeError(f"missing keyword arguments {display}")

            if merged_params:
                kwargs["params"] = merged_params

            # Optional per-class hook that may rewrite args/kwargs before the
            # real __init__ runs.
            hook = getattr(self, "_hook_apply_defaults", None)
            if hook:
                args, kwargs = hook(**kwargs, default_args=default_args)
                default_args = kwargs.pop("default_args", {})

            if not hasattr(self, "_BaseOperator__init_kwargs"):
                object.__setattr__(self, "_BaseOperator__init_kwargs", {})
            object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped)

            result = func(self, **kwargs, default_args=default_args)

            # Store the args passed to init -- we need them to support task.map serialization!
            self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore

            # Set upstream task defined by XComArgs passed to template fields of the operator.
            # BUT: only do this _ONCE_, not once for each class in the hierarchy
            if not instantiated_from_mapped and func == self.__init__.__wrapped__:  # type: ignore[misc]
                self._set_xcomargs_dependencies()
                # Mark instance as instantiated so that future attr setting updates xcomarg-based deps.
                object.__setattr__(self, "_BaseOperator__instantiated", True)

            return result

        # Exposed (name-mangled) metadata consumed by partial() above: which
        # parameters exist and which are mandatory.
        apply_defaults.__non_optional_args = non_optional_args  # type: ignore
        apply_defaults.__param_names = set(non_variadic_params)  # type: ignore

        return cast("T", apply_defaults)

    def __new__(cls, name, bases, namespace, **kwargs):
        # Wrap any concrete execute() so direct calls outside the Task Runner
        # are caught by ExecutorSafeguard.
        execute_method = namespace.get("execute")
        if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
            namespace["execute"] = ExecutorSafeguard.decorator(execute_method)
        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
        with contextlib.suppress(KeyError):
            # Update the partial descriptor with the class method, so it calls the actual function
            # (but let subclasses override it if they need to)
            partial_desc = vars(new_cls)["partial"]
            if isinstance(partial_desc, _PartialDescriptor):
                partial_desc.class_method = classmethod(partial)

        # We patch `__init__` only if the class defines it.
        first_superclass = new_cls.mro()[1]
        if new_cls.__init__ is not first_superclass.__init__:
            new_cls.__init__ = cls._apply_defaults(new_cls.__init__)

        return new_cls
585
586
# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the
# correct type. This is a temporary solution until we find a more sophisticated method for argument
# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not
# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
# version that supports `get_type_hints` effectively or find a better approach, we can replace this
# manual type-checking method.
# NOTE(review): presumably consumed by validate_instance_args (imported above) at
# operator construction time — confirm against its call site.
BASEOPERATOR_ARGS_EXPECTED_TYPES = {
    "task_id": str,
    "email": (str, Sequence),
    "email_on_retry": bool,
    "email_on_failure": bool,
    "retries": int,
    "retry_exponential_backoff": (int, float),
    "depends_on_past": bool,
    "ignore_first_depends_on_past": bool,
    "wait_for_past_depends_before_skipping": bool,
    "wait_for_downstream": bool,
    "priority_weight": int,
    "queue": str,
    "pool": str,
    "pool_slots": int,
    "trigger_rule": str,
    "run_as_user": str,
    "task_concurrency": int,
    "map_index_template": str,
    "max_active_tis_per_dag": int,
    "max_active_tis_per_dagrun": int,
    "executor": str,
    "do_xcom_push": bool,
    "multiple_outputs": bool,
    "doc": str,
    "doc_md": str,
    "doc_json": str,
    "doc_yaml": str,
    "doc_rst": str,
    "task_display_name": str,
    "logger_name": str,
    "allow_nested_operators": bool,
    "start_date": datetime,
    "end_date": datetime,
}
628
629
630# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in
631# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it.
632#
633# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you
634# get nowhere leave your record here for the next poor soul and what problems you ran in to.
635#
636# @ashb, 2024/10/14
637# - "Can't combine custom __setattr__ with on_setattr hooks"
638# - Setting class-wide `define(on_setarrs=...)` isn't called for non-attrs subclasses
639@total_ordering
640@dataclass(repr=False)
641class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
642 r"""
643 Abstract base class for all operators.
644
645 Since operators create objects that become nodes in the Dag, BaseOperator
646 contains many recursive methods for Dag crawling behavior. To derive from
647 this class, you are expected to override the constructor and the 'execute'
648 method.
649
650 Operators derived from this class should perform or trigger certain tasks
651 synchronously (wait for completion). Example of operators could be an
652 operator that runs a Pig job (PigOperator), a sensor operator that
653 waits for a partition to land in Hive (HiveSensorOperator), or one that
654 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these
655 operators (tasks) target specific operations, running specific scripts,
656 functions or data transfers.
657
658 This class is abstract and shouldn't be instantiated. Instantiating a
659 class derived from this one results in the creation of a task object,
660 which ultimately becomes a node in Dag objects. Task dependencies should
661 be set by using the set_upstream and/or set_downstream methods.
662
663 :param task_id: a unique, meaningful id for the task
664 :param owner: the owner of the task. Using a meaningful description
665 (e.g. user/person/team/role name) to clarify ownership is recommended.
666 :param email: the 'to' email address(es) used in email alerts. This can be a
667 single email or multiple ones. Multiple addresses can be specified as a
668 comma or semicolon separated string or by passing a list of strings. (deprecated)
669 :param email_on_retry: Indicates whether email alerts should be sent when a
670 task is retried (deprecated)
671 :param email_on_failure: Indicates whether email alerts should be sent when
672 a task failed (deprecated)
673 :param retries: the number of retries that should be performed before
674 failing the task
675 :param retry_delay: delay between retries, can be set as ``timedelta`` or
676 ``float`` seconds, which will be converted into ``timedelta``,
677 the default is ``timedelta(seconds=300)``.
678 :param retry_exponential_backoff: multiplier for exponential backoff between retries.
679 Set to 0 to disable (constant delay). Set to 2.0 for standard exponential backoff
680 (delay doubles with each retry). For example, with retry_delay=4min and
681 retry_exponential_backoff=5, retries occur after 4min, 20min, 100min, etc.
682 :param max_retry_delay: maximum delay interval between retries, can be set as
683 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``.
684 :param start_date: The ``start_date`` for the task, determines
685 the ``logical_date`` for the first task instance. The best practice
686 is to have the start_date rounded
687 to your Dag's ``schedule_interval``. Daily jobs have their start_date
688 some day at 00:00:00, hourly jobs have their start_date at 00:00
689 of a specific hour. Note that Airflow simply looks at the latest
690 ``logical_date`` and adds the ``schedule_interval`` to determine
691 the next ``logical_date``. It is also very important
692 to note that different tasks' dependencies
693 need to line up in time. If task A depends on task B and their
694 start_date are offset in a way that their logical_date don't line
695 up, A's dependencies will never be met. If you are looking to delay
696 a task, for example running a daily task at 2AM, look into the
697 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
698 dynamic ``start_date`` and recommend using fixed ones. Read the
699 FAQ entry about start_date for more information.
700 :param end_date: if specified, the scheduler won't go beyond this date
701 :param depends_on_past: when set to true, task instances will run
702 sequentially and only if the previous instance has succeeded or has been skipped.
703 The task instance for the start_date is allowed to run.
704 :param wait_for_past_depends_before_skipping: when set to true, if the task instance
705 should be marked as skipped, and depends_on_past is true, the ti will stay on None state
706 waiting the task of the previous run
707 :param wait_for_downstream: when set to true, an instance of task
708 X will wait for tasks immediately downstream of the previous instance
709 of task X to finish successfully or be skipped before it runs. This is useful if the
710 different instances of a task X alter the same asset, and this asset
711 is used by tasks downstream of task X. Note that depends_on_past
712 is forced to True wherever wait_for_downstream is used. Also note that
713 only tasks *immediately* downstream of the previous task instance are waited
714 for; the statuses of any tasks further downstream are ignored.
715 :param dag: a reference to the dag the task is attached to (if any)
716 :param priority_weight: priority weight of this task against other task.
717 This allows the executor to trigger higher priority tasks before
718 others when things get backed up. Set priority_weight as a higher
719 number for more important tasks.
720 As not all database engines support 64-bit integers, values are capped with 32-bit.
721 Valid range is from -2,147,483,648 to 2,147,483,647.
722 :param weight_rule: weighting method used for the effective total
723 priority weight of the task. Options are:
724 ``{ downstream | upstream | absolute }`` default is ``downstream``
725 When set to ``downstream`` the effective weight of the task is the
726 aggregate sum of all downstream descendants. As a result, upstream
727 tasks will have higher weight and will be scheduled more aggressively
728 when using positive weight values. This is useful when you have
729 multiple dag run instances and desire to have all upstream tasks to
730 complete for all runs before each dag can continue processing
731 downstream tasks. When set to ``upstream`` the effective weight is the
732 aggregate sum of all upstream ancestors. This is the opposite where
733 downstream tasks have higher weight and will be scheduled more
734 aggressively when using positive weight values. This is useful when you
735 have multiple dag run instances and prefer to have each dag complete
736 before starting upstream tasks of other dags. When set to
737 ``absolute``, the effective weight is the exact ``priority_weight``
738 specified without additional weighting. You may want to do this when
739 you know exactly what priority weight each task should have.
740 Additionally, when set to ``absolute``, there is bonus effect of
741 significantly speeding up the task creation process as for very large
742 Dags. Options can be set as string or using the constants defined in
743 the static class ``airflow.utils.WeightRule``.
744 Irrespective of the weight rule, resulting priority values are capped with 32-bit.
745 |experimental|
746 Since 2.9.0, Airflow allows to define custom priority weight strategy,
747 by creating a subclass of
748 ``airflow.task.priority_strategy.PriorityWeightStrategy`` and registering
749 in a plugin, then providing the class path or the class instance via
750 ``weight_rule`` parameter. The custom priority weight strategy will be
751 used to calculate the effective total priority weight of the task instance.
752 :param queue: which queue to target when running this job. Not
753 all executors implement queue management, the CeleryExecutor
754 does support targeting specific queues.
755 :param pool: the slot pool this task should run in, slot pools are a
756 way to limit concurrency for certain tasks
757 :param pool_slots: the number of pool slots this task should use (>= 1)
758 Values less than 1 are not allowed.
759 :param sla: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a
760 new implementation in Airflow >=3.1.
761 :param execution_timeout: max time allowed for the execution of
762 this task instance, if it goes beyond it will raise and fail.
763 :param on_failure_callback: a function or list of functions to be called when a task instance
764 of this task fails. a context dictionary is passed as a single
765 parameter to this function. Context contains references to related
766 objects to the task instance and is documented under the macros
767 section of the API.
768 :param on_execute_callback: much like the ``on_failure_callback`` except
769 that it is executed right before the task is executed.
770 :param on_retry_callback: much like the ``on_failure_callback`` except
771 that it is executed when retries occur.
772 :param on_success_callback: much like the ``on_failure_callback`` except
773 that it is executed when the task succeeds.
774 :param on_skipped_callback: much like the ``on_failure_callback`` except
        that it is executed when a skip occurs; this callback will be called only if AirflowSkipException gets raised.
776 Explicitly it is NOT called if a task is not started to be executed because of a preceding branching
777 decision in the Dag or a trigger rule which causes execution to skip so that the task execution
778 is never scheduled.
779 :param pre_execute: a function to be called immediately before task
780 execution, receiving a context dictionary; raising an exception will
781 prevent the task from being executed.
782 :param post_execute: a function to be called immediately after task
783 execution, receiving a context dictionary and task result; raising an
784 exception will prevent the task from succeeding.
785 :param trigger_rule: defines the rule by which dependencies are applied
786 for the task to get triggered. Options are:
787 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
788 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
789 default is ``all_success``. Options can be set as string or
790 using the constants defined in the static class
791 ``airflow.utils.TriggerRule``
792 :param resources: A map of resource parameter names (the argument names of the
793 Resources constructor) to their values.
794 :param run_as_user: unix username to impersonate while running the task
795 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent
796 runs across logical_dates.
797 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent
798 task instances per Dag run.
799 :param executor: Which executor to target when running this task. NOT YET SUPPORTED
800 :param executor_config: Additional task-level configuration parameters that are
801 interpreted by a specific executor. Parameters are namespaced by the name of
802 executor.
803
804 **Example**: to run this task in a specific docker container through
805 the KubernetesExecutor ::
806
807 MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}})
808
809 :param do_xcom_push: if True, an XCom is pushed containing the Operator's
810 result
811 :param multiple_outputs: if True and do_xcom_push is True, pushes multiple XComs, one for each
812 key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom.
813 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not
814 using a TaskGroup as a context manager.
815 :param doc: Add documentation or notes to your Task objects that is visible in
816 Task Instance details View in the Webserver
817 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects
818 that is visible in Task Instance details View in the Webserver
819 :param doc_rst: Add documentation (in RST format) or notes to your Task objects
820 that is visible in Task Instance details View in the Webserver
821 :param doc_json: Add documentation (in JSON format) or notes to your Task objects
822 that is visible in Task Instance details View in the Webserver
823 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects
824 that is visible in Task Instance details View in the Webserver
825 :param task_display_name: The display name of the task which appears on the UI.
826 :param logger_name: Name of the logger used by the Operator to emit logs.
827 If set to `None` (default), the logger name will fall back to
828 `airflow.task.operators.{class.__module__}.{class.__name__}` (e.g. HttpOperator will have
829 *airflow.task.operators.airflow.providers.http.operators.http.HttpOperator* as logger).
830 :param allow_nested_operators: if True, when an operator is executed within another one a warning message
831 will be logged. If False, then an exception will be raised if the operator is badly used (e.g. nested
832 within another one). In future releases of Airflow this parameter will be removed and an exception
833 will always be thrown when operators are nested within each other (default is True).
834
835 **Example**: example of a bad operator mixin usage::
836
837 @task(provide_context=True)
838 def say_hello_world(**context):
839 hello_world_task = BashOperator(
840 task_id="hello_world_task",
841 bash_command="python -c \"print('Hello, world!')\"",
842 dag=dag,
843 )
844 hello_world_task.execute(context)
845 """
846
    # -- Identity, ownership, and deprecated email notification fields -------
    task_id: str
    owner: str = DEFAULT_OWNER
    # NOTE(review): the email* fields are deprecated; __init__ warns and points
    # users to SmtpNotifier instead.
    email: str | Sequence[str] | None = None
    email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY
    email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE
    # -- Retry and scheduling knobs -------------------------------------------
    retries: int | None = DEFAULT_RETRIES
    retry_delay: timedelta = DEFAULT_RETRY_DELAY
    retry_exponential_backoff: float = 0
    max_retry_delay: timedelta | float | None = None
    start_date: datetime | None = None
    end_date: datetime | None = None
    depends_on_past: bool = False
    ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
    wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
    wait_for_downstream: bool = False

    # At execution_time this becomes a normal dict
    params: ParamsDict | dict = field(default_factory=ParamsDict)
    default_args: dict | None = None
    priority_weight: int = DEFAULT_PRIORITY_WEIGHT
    weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
    queue: str = DEFAULT_QUEUE
    pool: str = DEFAULT_POOL_NAME
    pool_slots: int = DEFAULT_POOL_SLOTS
    execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
    # -- Lifecycle callbacks; __init__ normalizes single callables/None to
    #    flat sequences via _collect_from_input() --------------------------------
    on_execute_callback: Sequence[TaskStateChangeCallback] = ()
    on_failure_callback: Sequence[TaskStateChangeCallback] = ()
    on_success_callback: Sequence[TaskStateChangeCallback] = ()
    on_retry_callback: Sequence[TaskStateChangeCallback] = ()
    on_skipped_callback: Sequence[TaskStateChangeCallback] = ()
    _pre_execute_hook: TaskPreExecuteHook | None = None
    _post_execute_hook: TaskPostExecuteHook | None = None
    trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE
    resources: dict[str, Any] | None = None
    run_as_user: str | None = None
    task_concurrency: int | None = None
    map_index_template: str | None = None
    max_active_tis_per_dag: int | None = None
    max_active_tis_per_dagrun: int | None = None
    executor: str | None = None
    executor_config: dict | None = None
    do_xcom_push: bool = True
    multiple_outputs: bool = False
    # Lineage inputs/outputs; also extended via __lt__/__gt__ operators.
    inlets: list[Any] = field(default_factory=list)
    outlets: list[Any] = field(default_factory=list)
    task_group: TaskGroup | None = None
    # -- Documentation rendered in the Task Instance details view -------------
    doc: str | None = None
    doc_md: str | None = None
    doc_json: str | None = None
    doc_yaml: str | None = None
    doc_rst: str | None = None
    # Backing field for the task_display_name property (falls back to task_id).
    _task_display_name: str | None = None
    logger_name: str | None = None
    allow_nested_operators: bool = True

    # Setup/teardown flags; the _convert_is_setup/_convert_is_teardown hooks
    # prevent a task from being marked as both.
    is_setup: bool = False
    is_teardown: bool = False

    # TODO: Task-SDK: Make these ClassVar[]?
    template_fields: Collection[str] = ()
    template_ext: Sequence[str] = ()

    template_fields_renderers: ClassVar[dict[str, str]] = {}

    operator_extra_links: Collection[BaseOperatorLink] = ()

    # Defines the color in the UI
    ui_color: str = "#fff"
    ui_fgcolor: str = "#000"

    partial: Callable[..., OperatorPartial] = _PartialDescriptor()  # type: ignore

    _dag: DAG | None = field(init=False, default=None)

    # Make this optional so the type matches the one define in LoggingMixin
    _log_config_logger_name: str | None = field(default="airflow.task.operators", init=False)
    _logger_name: str | None = None

    # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
    __serialized_fields: ClassVar[frozenset[str] | None] = None

    # Fields consulted by __eq__/__hash__ to decide operator equality.
    _comps: ClassVar[set[str]] = {
        "task_id",
        "dag_id",
        "owner",
        "email",
        "email_on_retry",
        "retry_delay",
        "retry_exponential_backoff",
        "max_retry_delay",
        "start_date",
        "end_date",
        "depends_on_past",
        "wait_for_downstream",
        "priority_weight",
        "execution_timeout",
        "has_on_execute_callback",
        "has_on_failure_callback",
        "has_on_success_callback",
        "has_on_retry_callback",
        "has_on_skipped_callback",
        "do_xcom_push",
        "multiple_outputs",
        "allow_nested_operators",
        "executor",
    }

    # If True, the Rendered Template fields will be overwritten in DB after execution
    # This is useful for Taskflow decorators that modify the template fields during execution like
    # @task.bash decorator.
    overwrite_rtif_after_execution: bool = False

    # If True then the class constructor was called
    __instantiated: bool = False
    # List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task
    # when mapping
    # Set via the metaclass
    __init_kwargs: dict[str, Any] = field(init=False)

    # Set to True before calling execute method
    _lock_for_execution: bool = False

    # Set to True for an operator instantiated by a mapped operator.
    __from_mapped: bool = False

    start_trigger_args: StartTriggerArgs | None = None
    start_from_trigger: bool = False

    # base list which includes all the attrs that don't need deep copy.
    _base_operator_shallow_copy_attrs: Final[tuple[str, ...]] = (
        "user_defined_macros",
        "user_defined_filters",
        "params",
    )

    # each operator should override this class attr for shallow copy attrs.
    shallow_copy_attrs: Sequence[str] = ()
984
    def __setattr__(self: BaseOperator, key: str, value: Any):
        """
        Intercept every attribute assignment.

        Values are first routed through an optional per-field converter named
        ``_convert_<key>`` (e.g. ``_convert_retries``) before being stored; after
        construction, assignments are also mirrored into ``__init_kwargs`` (used to
        re-create the task when mapping) and checked for XComArg template-field
        values so upstream relationships get wired automatically.
        """
        if converter := getattr(self, f"_convert_{key}", None):
            value = converter(value)
        super().__setattr__(key, value)
        if self.__from_mapped or self._lock_for_execution:
            return  # Skip any custom behavior for validation and during execute.
        if key in self.__init_kwargs:
            # Keep the recorded constructor arguments in sync so mapped-task
            # re-creation sees the latest value.
            self.__init_kwargs[key] = value
        if self.__instantiated and key in self.template_fields:
            # Resolve upstreams set by assigning an XComArg after initializing
            # an operator, example:
            #   op = BashOperator()
            #   op.bash_command = "sleep 1"
            self._set_xcomargs_dependency(key, value)
999
    def __init__(
        self,
        *,
        task_id: str,
        owner: str = DEFAULT_OWNER,
        email: str | Sequence[str] | None = None,
        email_on_retry: bool = DEFAULT_EMAIL_ON_RETRY,
        email_on_failure: bool = DEFAULT_EMAIL_ON_FAILURE,
        retries: int | None = DEFAULT_RETRIES,
        retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
        retry_exponential_backoff: float = 0,
        max_retry_delay: timedelta | float | None = None,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        depends_on_past: bool = False,
        ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
        wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
        wait_for_downstream: bool = False,
        dag: DAG | None = None,
        params: collections.abc.MutableMapping[str, Any] | None = None,
        default_args: dict | None = None,
        priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
        weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
        queue: str = DEFAULT_QUEUE,
        pool: str | None = None,
        pool_slots: int = DEFAULT_POOL_SLOTS,
        sla: timedelta | None = None,
        execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
        on_execute_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
        on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
        on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
        on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
        on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
        pre_execute: TaskPreExecuteHook | None = None,
        post_execute: TaskPostExecuteHook | None = None,
        trigger_rule: str = DEFAULT_TRIGGER_RULE,
        resources: dict[str, Any] | None = None,
        run_as_user: str | None = None,
        map_index_template: str | None = None,
        max_active_tis_per_dag: int | None = None,
        max_active_tis_per_dagrun: int | None = None,
        executor: str | None = None,
        executor_config: dict | None = None,
        do_xcom_push: bool = True,
        multiple_outputs: bool = False,
        inlets: Any | None = None,
        outlets: Any | None = None,
        task_group: TaskGroup | None = None,
        doc: str | None = None,
        doc_md: str | None = None,
        doc_json: str | None = None,
        doc_yaml: str | None = None,
        doc_rst: str | None = None,
        task_display_name: str | None = None,
        logger_name: str | None = None,
        allow_nested_operators: bool = True,
        **kwargs: Any,
    ):
        """Validate arguments and initialize operator state; parameters are documented in the class docstring."""
        # Note: Metaclass handles passing in the Dag/TaskGroup from active context manager, if any

        # Only apply task_group prefix if this operator was not created from a mapped operator
        # Mapped operators already have the prefix applied during their creation
        if task_group and not self.__from_mapped:
            self.task_id = task_group.child_id(task_id)
            task_group.add(self)
        else:
            self.task_id = task_id

        super().__init__()
        self.task_group = task_group

        # Any keyword arguments left over at this point are unknown — reject them
        # loudly (redacted, in case they contain secrets).
        kwargs.pop("_airflow_mapped_validation_only", None)
        if kwargs:
            raise TypeError(
                f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
                f"Invalid arguments were:\n**kwargs: {redact(kwargs)}",
            )
        validate_key(self.task_id)

        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        # The email* family is deprecated in favor of SmtpNotifier.
        if email is not None:
            warnings.warn(
                "Setting email on a task is deprecated; please migrate to SmtpNotifier.",
                RemovedInAirflow4Warning,
                stacklevel=2,
            )
        # NOTE(review): email_on_retry/email_on_failure are bools with non-None
        # defaults, so the `is not None` checks below are always true — these two
        # warnings fire whenever `email` is truthy at all. Confirm this is intended.
        if email and email_on_retry is not None:
            warnings.warn(
                "Setting email_on_retry on a task is deprecated; please migrate to SmtpNotifier.",
                RemovedInAirflow4Warning,
                stacklevel=2,
            )
        if email and email_on_failure is not None:
            warnings.warn(
                "Setting email_on_failure on a task is deprecated; please migrate to SmtpNotifier.",
                RemovedInAirflow4Warning,
                stacklevel=2,
            )

        if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
            raise ValueError(
                f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
            )
        self.execution_timeout = execution_timeout

        # Normalize None / single-callable callback arguments to flat sequences.
        self.on_execute_callback = _collect_from_input(on_execute_callback)
        self.on_failure_callback = _collect_from_input(on_failure_callback)
        self.on_success_callback = _collect_from_input(on_success_callback)
        self.on_retry_callback = _collect_from_input(on_retry_callback)
        self.on_skipped_callback = _collect_from_input(on_skipped_callback)
        self._pre_execute_hook = pre_execute
        self._post_execute_hook = post_execute

        self.start_date = timezone.convert_to_utc(start_date)
        self.end_date = timezone.convert_to_utc(end_date)
        self.executor = executor
        self.executor_config = executor_config or {}
        self.run_as_user = run_as_user
        # TODO:
        # self.retries = parse_retries(retries)
        self.retries = retries
        self.queue = queue
        self.pool = DEFAULT_POOL_NAME if pool is None else pool
        self.pool_slots = pool_slots
        if self.pool_slots < 1:
            dag_str = f" in dag {dag.dag_id}" if dag else ""
            raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
        # SLA support was removed in Airflow 3; accept-and-warn keeps old Dags parseable.
        if sla is not None:
            warnings.warn(
                "The SLA feature is removed in Airflow 3.0, replaced with Deadline Alerts in >=3.1",
                stacklevel=2,
            )

        # Validate eagerly so an invalid trigger_rule fails at Dag-parse time.
        try:
            TriggerRule(trigger_rule)
        except ValueError:
            raise ValueError(
                f"The trigger_rule must be one of {[rule.value for rule in TriggerRule]},"
                f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
            )

        self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)

        self.depends_on_past: bool = depends_on_past
        self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
        self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
        self.wait_for_downstream: bool = wait_for_downstream
        # wait_for_downstream only makes sense together with depends_on_past.
        if wait_for_downstream:
            self.depends_on_past = True

        # Converted by setattr (see _convert_retry_delay/_convert_max_retry_delay).
        self.retry_delay = retry_delay  # type: ignore[assignment]
        self.retry_exponential_backoff = retry_exponential_backoff
        if max_retry_delay is not None:
            self.max_retry_delay = max_retry_delay

        self.resources = resources

        self.params = ParamsDict(params)

        self.priority_weight = priority_weight
        self.weight_rule = weight_rule

        self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
        self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
        self.do_xcom_push: bool = do_xcom_push
        self.map_index_template: str | None = map_index_template
        self.multiple_outputs: bool = multiple_outputs

        self.doc_md = doc_md
        self.doc_json = doc_json
        self.doc_yaml = doc_yaml
        self.doc_rst = doc_rst
        self.doc = doc

        self._task_display_name = task_display_name

        self.allow_nested_operators = allow_nested_operators

        self._logger_name = logger_name

        # Lineage
        self.inlets = _collect_from_input(inlets)
        self.outlets = _collect_from_input(outlets)

        # Tolerate (with a warning) the common mistake of template_fields = "field".
        if isinstance(self.template_fields, str):
            warnings.warn(
                f"The `template_fields` value for {self.task_type} is a string "
                "but should be a list or tuple of string. Wrapping it in a list for execution. "
                f"Please update {self.task_type} accordingly.",
                UserWarning,
                stacklevel=2,
            )
            self.template_fields = [self.template_fields]

        self.is_setup = False
        self.is_teardown = False

        if SetupTeardownContext.active:
            SetupTeardownContext.update_context_map(self)

        # We set self.dag right at the end as `_convert_dag` calls `dag.add_task` for us, and we need all the
        # other properties to be set at that point
        if dag is not None:
            self.dag = dag

        validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)

        # Ensure priority_weight is within the valid range
        self.priority_weight = db_safe_priority(self.priority_weight)
1214
1215 def __eq__(self, other):
1216 if type(self) is type(other):
1217 # Use getattr() instead of __dict__ as __dict__ doesn't return
1218 # correct values for properties.
1219 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
1220 return False
1221
1222 def __ne__(self, other):
1223 return not self == other
1224
1225 def __hash__(self):
1226 hash_components = [type(self)]
1227 for component in self._comps:
1228 val = getattr(self, component, None)
1229 try:
1230 hash(val)
1231 hash_components.append(val)
1232 except TypeError:
1233 hash_components.append(repr(val))
1234 return hash(tuple(hash_components))
1235
1236 # /Composing Operators ---------------------------------------------
1237
1238 def __gt__(self, other):
1239 """
1240 Return [Operator] > [Outlet].
1241
1242 If other is an attr annotated object it is set as an outlet of this Operator.
1243 """
1244 if not isinstance(other, Iterable):
1245 other = [other]
1246
1247 for obj in other:
1248 if not attrs.has(obj):
1249 raise TypeError(f"Left hand side ({obj}) is not an outlet")
1250 self.add_outlets(other)
1251
1252 return self
1253
1254 def __lt__(self, other):
1255 """
1256 Return [Inlet] > [Operator] or [Operator] < [Inlet].
1257
1258 If other is an attr annotated object it is set as an inlet to this operator.
1259 """
1260 if not isinstance(other, Iterable):
1261 other = [other]
1262
1263 for obj in other:
1264 if not attrs.has(obj):
1265 raise TypeError(f"{obj} cannot be an inlet")
1266 self.add_inlets(other)
1267
1268 return self
1269
1270 def __deepcopy__(self, memo: dict[int, Any]):
1271 # Hack sorting double chained task lists by task_id to avoid hitting
1272 # max_depth on deepcopy operations.
1273 sys.setrecursionlimit(5000) # TODO fix this in a better way
1274
1275 cls = self.__class__
1276 result = cls.__new__(cls)
1277 memo[id(self)] = result
1278
1279 shallow_copy = tuple(cls.shallow_copy_attrs) + cls._base_operator_shallow_copy_attrs
1280
1281 for k, v_org in self.__dict__.items():
1282 if k not in shallow_copy:
1283 v = copy.deepcopy(v_org, memo)
1284 else:
1285 v = copy.copy(v_org)
1286
1287 # Bypass any setters, and set it on the object directly. This works since we are cloning ourself so
1288 # we know the type is already fine
1289 result.__dict__[k] = v
1290 return result
1291
1292 def __getstate__(self):
1293 state = dict(self.__dict__)
1294 if "_log" in state:
1295 del state["_log"]
1296
1297 return state
1298
1299 def __setstate__(self, state):
1300 self.__dict__ = state
1301
1302 def add_inlets(self, inlets: Iterable[Any]):
1303 """Set inlets to this operator."""
1304 self.inlets.extend(inlets)
1305
1306 def add_outlets(self, outlets: Iterable[Any]):
1307 """Define the outlets of this operator."""
1308 self.outlets.extend(outlets)
1309
1310 def get_dag(self) -> DAG | None:
1311 return self._dag
1312
1313 @property
1314 def dag(self) -> DAG:
1315 """Returns the Operator's Dag if set, otherwise raises an error."""
1316 if dag := self._dag:
1317 return dag
1318 raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet")
1319
1320 @dag.setter
1321 def dag(self, dag: DAG | None) -> None:
1322 """Operators can be assigned to one Dag, one time. Repeat assignments to that same Dag are ok."""
1323 self._dag = dag
1324
1325 def _convert__dag(self, dag: DAG | None) -> DAG | None:
1326 # Called automatically by __setattr__ method
1327 from airflow.sdk.definitions.dag import DAG
1328
1329 if dag is None:
1330 return dag
1331
1332 if not isinstance(dag, DAG):
1333 raise TypeError(f"Expected dag; received {dag.__class__.__name__}")
1334 if self._dag is not None and self._dag is not dag:
1335 raise ValueError(f"The dag assigned to {self} can not be changed.")
1336
1337 if self.__from_mapped:
1338 pass # Don't add to dag -- the mapped task takes the place.
1339 elif dag.task_dict.get(self.task_id) is not self:
1340 dag.add_task(self)
1341 return dag
1342
1343 @staticmethod
1344 def _convert_retries(retries: Any) -> int | None:
1345 if retries is None:
1346 return 0
1347 if type(retries) == int: # noqa: E721
1348 return retries
1349 try:
1350 parsed_retries = int(retries)
1351 except (TypeError, ValueError):
1352 raise TypeError(f"'retries' type must be int, not {type(retries).__name__}")
1353 return parsed_retries
1354
1355 @staticmethod
1356 def _convert_timedelta(value: float | timedelta | None) -> timedelta | None:
1357 if value is None or isinstance(value, timedelta):
1358 return value
1359 return timedelta(seconds=value)
1360
1361 _convert_retry_delay = _convert_timedelta
1362 _convert_max_retry_delay = _convert_timedelta
1363
1364 @staticmethod
1365 def _convert_resources(resources: dict[str, Any] | None) -> Resources | None:
1366 if resources is None:
1367 return None
1368
1369 from airflow.sdk.definitions.operator_resources import Resources
1370
1371 if isinstance(resources, Resources):
1372 return resources
1373
1374 return Resources(**resources)
1375
1376 def _convert_is_setup(self, value: bool) -> bool:
1377 """
1378 Setter for is_setup property.
1379
1380 :meta private:
1381 """
1382 if self.is_teardown and value:
1383 raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.")
1384 return value
1385
1386 def _convert_is_teardown(self, value: bool) -> bool:
1387 if self.is_setup and value:
1388 raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.")
1389 return value
1390
1391 @property
1392 def task_display_name(self) -> str:
1393 return self._task_display_name or self.task_id
1394
1395 def has_dag(self):
1396 """Return True if the Operator has been assigned to a Dag."""
1397 return self._dag is not None
1398
1399 def _set_xcomargs_dependencies(self) -> None:
1400 from airflow.sdk.definitions.xcom_arg import XComArg
1401
1402 for f in self.template_fields:
1403 arg = getattr(self, f, NOTSET)
1404 if arg is not NOTSET:
1405 XComArg.apply_upstream_relationship(self, arg)
1406
1407 def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None:
1408 """
1409 Resolve upstream dependencies of a task.
1410
1411 In this way passing an ``XComArg`` as value for a template field
1412 will result in creating upstream relation between two tasks.
1413
1414 **Example**: ::
1415
1416 with DAG(...):
1417 generate_content = GenerateContentOperator(task_id="generate_content")
1418 send_email = EmailOperator(..., html_content=generate_content.output)
1419
1420 # This is equivalent to
1421 with DAG(...):
1422 generate_content = GenerateContentOperator(task_id="generate_content")
1423 send_email = EmailOperator(
1424 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
1425 )
1426 generate_content >> send_email
1427
1428 """
1429 from airflow.sdk.definitions.xcom_arg import XComArg
1430
1431 if field not in self.template_fields:
1432 return
1433 XComArg.apply_upstream_relationship(self, newvalue)
1434
1435 def on_kill(self) -> None:
1436 """
1437 Override this method to clean up subprocesses when a task instance gets killed.
1438
1439 Any use of the threading, subprocess or multiprocessing module within an
1440 operator needs to be cleaned up, or it will leave ghost processes behind.
1441 """
1442
1443 def __repr__(self):
1444 return f"<Task({self.task_type}): {self.task_id}>"
1445
1446 @property
1447 def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
1448 return self.__class__
1449
1450 @property
1451 def task_type(self) -> str:
1452 """@property: type of the task."""
1453 return self.__class__.__name__
1454
1455 @property
1456 def operator_name(self) -> str:
1457 """@property: use a more friendly display name for the operator, if set."""
1458 try:
1459 return self.custom_operator_name # type: ignore
1460 except AttributeError:
1461 return self.task_type
1462
1463 @property
1464 def roots(self) -> list[BaseOperator]:
1465 """Required by DAGNode."""
1466 return [self]
1467
1468 @property
1469 def leaves(self) -> list[BaseOperator]:
1470 """Required by DAGNode."""
1471 return [self]
1472
1473 @property
1474 def output(self) -> XComArg:
1475 """Returns reference to XCom pushed by current operator."""
1476 from airflow.sdk.definitions.xcom_arg import XComArg
1477
1478 return XComArg(operator=self)
1479
1480 @classmethod
1481 def get_serialized_fields(cls):
1482 """Stringified Dags and operators contain exactly these fields."""
1483 if not cls.__serialized_fields:
1484 from airflow.sdk.definitions._internal.contextmanager import DagContext
1485
1486 # make sure the following "fake" task is not added to current active
1487 # dag in context, otherwise, it will result in
1488 # `RuntimeError: dictionary changed size during iteration`
1489 # Exception in SerializedDAG.serialize_dag() call.
1490 DagContext.push(None)
1491 cls.__serialized_fields = frozenset(
1492 vars(BaseOperator(task_id="test")).keys()
1493 - {
1494 "upstream_task_ids",
1495 "default_args",
1496 "dag",
1497 "_dag",
1498 "label",
1499 "_BaseOperator__instantiated",
1500 "_BaseOperator__init_kwargs",
1501 "_BaseOperator__from_mapped",
1502 "on_failure_fail_dagrun",
1503 "task_group",
1504 "_task_type",
1505 "operator_extra_links",
1506 "on_execute_callback",
1507 "on_failure_callback",
1508 "on_success_callback",
1509 "on_retry_callback",
1510 "on_skipped_callback",
1511 }
1512 | { # Class level defaults, or `@property` need to be added to this list
1513 "start_date",
1514 "end_date",
1515 "task_type",
1516 "ui_color",
1517 "ui_fgcolor",
1518 "template_ext",
1519 "template_fields",
1520 "template_fields_renderers",
1521 "params",
1522 "is_setup",
1523 "is_teardown",
1524 "on_failure_fail_dagrun",
1525 "map_index_template",
1526 "start_trigger_args",
1527 "_needs_expansion",
1528 "start_from_trigger",
1529 "max_retry_delay",
1530 "has_on_execute_callback",
1531 "has_on_failure_callback",
1532 "has_on_success_callback",
1533 "has_on_retry_callback",
1534 "has_on_skipped_callback",
1535 }
1536 )
1537 DagContext.pop()
1538
1539 return cls.__serialized_fields
1540
1541 def prepare_for_execution(self) -> Self:
1542 """Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
1543 other = copy.copy(self)
1544 other._lock_for_execution = True
1545 return other
1546
1547 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
1548 """Serialize; required by DAGNode."""
1549 from airflow.serialization.enums import DagAttributeTypes
1550
1551 return DagAttributeTypes.OP, self.task_id
1552
1553 def unmap(self, resolve: None | Mapping[str, Any]) -> Self:
1554 """
1555 Get the "normal" operator from the current operator.
1556
1557 Since a BaseOperator is not mapped to begin with, this simply returns
1558 the original operator.
1559
1560 :meta private:
1561 """
1562 return self
1563
1564 def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
1565 """
1566 Get the start_trigger_args value of the current abstract operator.
1567
1568 Since a BaseOperator is not mapped to begin with, this simply returns
1569 the original value of start_trigger_args.
1570
1571 :meta private:
1572 """
1573 return self.start_trigger_args
1574
1575 def render_template_fields(
1576 self,
1577 context: Context,
1578 jinja_env: jinja2.Environment | None = None,
1579 ) -> None:
1580 """
1581 Template all attributes listed in *self.template_fields*.
1582
1583 This mutates the attributes in-place and is irreversible.
1584
1585 :param context: Context dict with values to apply on content.
1586 :param jinja_env: Jinja's environment to use for rendering.
1587 """
1588 if not jinja_env:
1589 jinja_env = self.get_template_env()
1590 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
1591
1592 def pre_execute(self, context: Any):
1593 """Execute right before self.execute() is called."""
1594
1595 def execute(self, context: Context) -> Any:
1596 """
1597 Derive when creating an operator.
1598
1599 The main method to execute the task. Context is the same dictionary used
1600 as when rendering jinja templates.
1601
1602 Refer to get_template_context for more context.
1603 """
1604 raise NotImplementedError()
1605
1606 def post_execute(self, context: Any, result: Any = None):
1607 """
1608 Execute right after self.execute() is called.
1609
1610 It is passed the execution context and any results returned by the operator.
1611 """
1612
1613 def defer(
1614 self,
1615 *,
1616 trigger: BaseTrigger,
1617 method_name: str,
1618 kwargs: dict[str, Any] | None = None,
1619 timeout: timedelta | int | float | None = None,
1620 ) -> NoReturn:
1621 """
1622 Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
1623
1624 This is achieved by raising a special exception (TaskDeferred)
1625 which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end
1626 the task instance directly. If the trigger will end the task instance itself, ``method_name`` should
1627 be None; otherwise, provide the name of the method that should be used when resuming execution in
1628 the task.
1629 """
1630 from airflow.sdk.exceptions import TaskDeferred
1631
1632 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
1633
1634 def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
1635 """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed."""
1636 from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout
1637
1638 if next_kwargs is None:
1639 next_kwargs = {}
1640 # __fail__ is a special signal value for next_method that indicates
1641 # this task was scheduled specifically to fail.
1642
1643 if next_method == TRIGGER_FAIL_REPR:
1644 next_kwargs = next_kwargs or {}
1645 traceback = next_kwargs.get("traceback")
1646 if traceback is not None:
1647 self.log.error("Trigger failed:\n%s", "\n".join(traceback))
1648 if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT:
1649 raise TaskDeferralTimeout(error)
1650 raise TaskDeferralError(error)
1651 # Grab the callable off the Operator/Task and add in any kwargs
1652 execute_callable = getattr(self, next_method)
1653 return execute_callable(context, **next_kwargs)
1654
1655 def dry_run(self) -> None:
1656 """Perform dry run for the operator - just render template fields."""
1657 self.log.info("Dry run")
1658 for f in self.template_fields:
1659 try:
1660 content = getattr(self, f)
1661 except AttributeError:
1662 raise AttributeError(
1663 f"{f!r} is configured as a template field "
1664 f"but {self.task_type} does not have this attribute."
1665 )
1666
1667 if content and isinstance(content, str):
1668 self.log.info("Rendering template for %s", f)
1669 self.log.info(content)
1670
    @property
    def has_on_execute_callback(self) -> bool:
        """
        Return True if the task has execute callbacks.

        Uses truthiness, so ``None`` or an empty collection of callbacks both
        count as "no callback".
        """
        return bool(self.on_execute_callback)
1675
    @property
    def has_on_failure_callback(self) -> bool:
        """
        Return True if the task has failure callbacks.

        Uses truthiness, so ``None`` or an empty collection of callbacks both
        count as "no callback".
        """
        return bool(self.on_failure_callback)
1680
    @property
    def has_on_success_callback(self) -> bool:
        """
        Return True if the task has success callbacks.

        Uses truthiness, so ``None`` or an empty collection of callbacks both
        count as "no callback".
        """
        return bool(self.on_success_callback)
1685
    @property
    def has_on_retry_callback(self) -> bool:
        """
        Return True if the task has retry callbacks.

        Uses truthiness, so ``None`` or an empty collection of callbacks both
        count as "no callback".
        """
        return bool(self.on_retry_callback)
1690
    @property
    def has_on_skipped_callback(self) -> bool:
        """
        Return True if the task has skipped callbacks.

        Uses truthiness, so ``None`` or an empty collection of callbacks both
        count as "no callback".
        """
        return bool(self.on_skipped_callback)
1695
1696
class BaseAsyncOperator(BaseOperator):
    """
    Base class for async-capable operators.

    As opposed to deferred operators which are executed on the triggerer, async operators are executed
    on the worker.
    """

    @property
    def is_async(self) -> bool:
        # Marker consulted by the runner to distinguish async-capable tasks.
        return True

    async def aexecute(self, context):
        """Async version of execute(). Subclasses should implement this."""
        raise NotImplementedError

    def execute(self, context):
        """Run `aexecute()` inside an event loop."""
        with event_loop() as loop:
            awaitable = self.aexecute(context)
            if self.execution_timeout:
                # Bound the coroutine's runtime by the task's execution_timeout.
                awaitable = asyncio.wait_for(
                    awaitable,
                    timeout=self.execution_timeout.total_seconds(),
                )
            return loop.run_until_complete(awaitable)
1724
1725
def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
    r"""
    Given a number of tasks, builds a dependency chain.

    This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
    or lists containing any mix of these types (or a mix in the same list). If you want to chain between two
    lists you must ensure they have the same length.

    Using classic operators/sensors:

    .. code-block:: python

        chain(t1, [t2, t3], [t4, t5], t6)

    is equivalent to::

          / -> t2 -> t4 \
        t1               -> t6
          \ -> t3 -> t5 /

    .. code-block:: python

        t1.set_downstream(t2)
        t1.set_downstream(t3)
        t2.set_downstream(t4)
        t3.set_downstream(t5)
        t4.set_downstream(t6)
        t5.set_downstream(t6)

    Using task-decorated functions aka XComArgs:

    .. code-block:: python

        chain(x1(), [x2(), x3()], [x4(), x5()], x6())

    is equivalent to::

          / -> x2 -> x4 \
        x1               -> x6
          \ -> x3 -> x5 /

    .. code-block:: python

        x1 = x1()
        x2 = x2()
        x3 = x3()
        x4 = x4()
        x5 = x5()
        x6 = x6()
        x1.set_downstream(x2)
        x1.set_downstream(x3)
        x2.set_downstream(x4)
        x3.set_downstream(x5)
        x4.set_downstream(x6)
        x5.set_downstream(x6)

    Using TaskGroups:

    .. code-block:: python

        chain(t1, task_group1, task_group2, t2)

        t1.set_downstream(task_group1)
        task_group1.set_downstream(task_group2)
        task_group2.set_downstream(t2)


    It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:

    .. code-block:: python

        chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())

    is equivalent to::

          / "branch one" -> x1 \
        t1                      -> task_group1 -> x3
          \ "branch two" -> x2 /

    .. code-block:: python

        x1 = x1()
        x2 = x2()
        x3 = x3()
        label1 = Label("branch one")
        label2 = Label("branch two")
        t1.set_downstream(label1)
        label1.set_downstream(x1)
        t2.set_downstream(label2)
        label2.set_downstream(x2)
        x1.set_downstream(task_group1)
        x2.set_downstream(task_group1)
        task_group1.set_downstream(x3)

        # or

        x1 = x1()
        x2 = x2()
        x3 = x3()
        t1.set_downstream(x1, edge_modifier=Label("branch one"))
        t1.set_downstream(x2, edge_modifier=Label("branch two"))
        x1.set_downstream(task_group1)
        x2.set_downstream(task_group1)
        task_group1.set_downstream(x3)


    :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
    """
    # Walk consecutive pairs and wire each pair together.
    for upstream, downstream in zip(tasks, tasks[1:]):
        if isinstance(upstream, DependencyMixin):
            upstream.set_downstream(downstream)
        elif isinstance(downstream, DependencyMixin):
            downstream.set_upstream(upstream)
        elif isinstance(upstream, Sequence) and isinstance(downstream, Sequence):
            # Two sibling lists: wire them element-wise, which requires equal lengths.
            if len(upstream) != len(downstream):
                raise ValueError(
                    f"Chain not supported for different length Iterable. "
                    f"Got {len(upstream)} and {len(downstream)}."
                )
            for up_t, down_t in zip(upstream, downstream):
                up_t.set_downstream(down_t)
        else:
            raise TypeError(f"Chain not supported between instances of {type(upstream)} and {type(downstream)}")
1852
1853
1854def cross_downstream(
1855 from_tasks: Sequence[DependencyMixin],
1856 to_tasks: DependencyMixin | Sequence[DependencyMixin],
1857):
1858 r"""
1859 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
1860
1861 Using classic operators/sensors:
1862
1863 .. code-block:: python
1864
1865 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
1866
1867 is equivalent to::
1868
1869 t1 ---> t4
1870 \ /
1871 t2 -X -> t5
1872 / \
1873 t3 ---> t6
1874
1875 .. code-block:: python
1876
1877 t1.set_downstream(t4)
1878 t1.set_downstream(t5)
1879 t1.set_downstream(t6)
1880 t2.set_downstream(t4)
1881 t2.set_downstream(t5)
1882 t2.set_downstream(t6)
1883 t3.set_downstream(t4)
1884 t3.set_downstream(t5)
1885 t3.set_downstream(t6)
1886
1887 Using task-decorated functions aka XComArgs:
1888
1889 .. code-block:: python
1890
1891 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
1892
1893 is equivalent to::
1894
1895 x1 ---> x4
1896 \ /
1897 x2 -X -> x5
1898 / \
1899 x3 ---> x6
1900
1901 .. code-block:: python
1902
1903 x1 = x1()
1904 x2 = x2()
1905 x3 = x3()
1906 x4 = x4()
1907 x5 = x5()
1908 x6 = x6()
1909 x1.set_downstream(x4)
1910 x1.set_downstream(x5)
1911 x1.set_downstream(x6)
1912 x2.set_downstream(x4)
1913 x2.set_downstream(x5)
1914 x2.set_downstream(x6)
1915 x3.set_downstream(x4)
1916 x3.set_downstream(x5)
1917 x3.set_downstream(x6)
1918
1919 It is also possible to mix between classic operator/sensor and XComArg tasks:
1920
1921 .. code-block:: python
1922
1923 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1924
1925 is equivalent to::
1926
1927 t1 ---> x1
1928 \ /
1929 x2 -X -> t2
1930 / \
1931 t3 ---> x3
1932
1933 .. code-block:: python
1934
1935 x1 = x1()
1936 x2 = x2()
1937 x3 = x3()
1938 t1.set_downstream(x1)
1939 t1.set_downstream(t2)
1940 t1.set_downstream(x3)
1941 x2.set_downstream(x1)
1942 x2.set_downstream(t2)
1943 x2.set_downstream(x3)
1944 t3.set_downstream(x1)
1945 t3.set_downstream(t2)
1946 t3.set_downstream(x3)
1947
1948 :param from_tasks: List of tasks or XComArgs to start from.
1949 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1950 """
1951 for task in from_tasks:
1952 task.set_downstream(to_tasks)
1953
1954
def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
    """
    Simplify task dependency definition.

    E.g.: suppose you want precedence like so::

            ╭─op2─╮ ╭─op4─╮
        op1─┤     ├─├─op5─┤─op7
            ╰-op3─╯ ╰-op6─╯

    Then you can accomplish like so::

        chain_linear(op1, [op2, op3], [op4, op5, op6], op7)

    :param elements: a list of operators / lists of operators
    """
    if not elements:
        raise ValueError("No tasks provided; nothing to do.")
    previous = None
    made_dependency = False
    for element in elements:
        if isinstance(element, EdgeModifier):
            raise ValueError("Labels are not supported by chain_linear")
        if previous is not None:
            # Every task in the previous group points at the current element.
            for upstream in previous:
                upstream >> element
            made_dependency = True
        # Normalize single tasks into a one-element list for the next pass.
        previous = [element] if isinstance(element, DependencyMixin) else element
    if not made_dependency:
        raise ValueError("No dependencies were set. Did you forget to expand with `*`?")