Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/baseoperator.py: 56%
601 statements
coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Base operator for all operators."""
19from __future__ import annotations
21import abc
22import collections
23import collections.abc
24import contextlib
25import copy
26import functools
27import logging
28import sys
29import warnings
30from abc import ABCMeta, abstractmethod
31from datetime import datetime, timedelta
32from inspect import signature
33from types import ClassMethodDescriptorType, FunctionType
34from typing import (
35 TYPE_CHECKING,
36 Any,
37 Callable,
38 ClassVar,
39 Collection,
40 Iterable,
41 List,
42 Sequence,
43 Type,
44 TypeVar,
45 Union,
46 cast,
47)
49import attr
50import pendulum
51from dateutil.relativedelta import relativedelta
52from sqlalchemy.orm import Session
53from sqlalchemy.orm.exc import NoResultFound
55from airflow.configuration import conf
56from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning, TaskDeferred
57from airflow.lineage import apply_lineage, prepare_lineage
58from airflow.models.abstractoperator import (
59 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
60 DEFAULT_OWNER,
61 DEFAULT_POOL_SLOTS,
62 DEFAULT_PRIORITY_WEIGHT,
63 DEFAULT_QUEUE,
64 DEFAULT_RETRIES,
65 DEFAULT_RETRY_DELAY,
66 DEFAULT_TASK_EXECUTION_TIMEOUT,
67 DEFAULT_TRIGGER_RULE,
68 DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
69 DEFAULT_WEIGHT_RULE,
70 AbstractOperator,
71 TaskStateChangeCallback,
72)
73from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs
74from airflow.models.param import ParamsDict
75from airflow.models.pool import Pool
76from airflow.models.taskinstance import TaskInstance, clear_task_instances
77from airflow.models.taskmixin import DAGNode, DependencyMixin
78from airflow.serialization.enums import DagAttributeTypes
79from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
80from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
81from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
82from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
83from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
84from airflow.triggers.base import BaseTrigger
85from airflow.utils import timezone
86from airflow.utils.context import Context
87from airflow.utils.decorators import fixup_decorator_warning_stack
88from airflow.utils.helpers import validate_key
89from airflow.utils.operator_resources import Resources
90from airflow.utils.session import NEW_SESSION, provide_session
91from airflow.utils.setup_teardown import SetupTeardownContext
92from airflow.utils.trigger_rule import TriggerRule
93from airflow.utils.types import NOTSET, ArgNotSet
94from airflow.utils.weight_rule import WeightRule
95from airflow.utils.xcom import XCOM_RETURN_KEY
97if TYPE_CHECKING:
98 import jinja2 # Slow import.
100 from airflow.models.dag import DAG
101 from airflow.models.taskinstancekey import TaskInstanceKey
102 from airflow.models.xcom_arg import XComArg
103 from airflow.utils.task_group import TaskGroup
105ScheduleInterval = Union[str, timedelta, relativedelta]
107TaskPreExecuteHook = Callable[[Context], None]
108TaskPostExecuteHook = Callable[[Context, Any], None]
110T = TypeVar("T", bound=FunctionType)
112logger = logging.getLogger("airflow.models.baseoperator.BaseOperator")
115def parse_retries(retries: Any) -> int | None:
116 if retries is None or isinstance(retries, int):
117 return retries
118 try:
119 parsed_retries = int(retries)
120 except (TypeError, ValueError):
121 raise AirflowException(f"'retries' type must be int, not {type(retries).__name__}")
122 logger.warning("Implicitly converting 'retries' from %r to int", retries)
123 return parsed_retries
126def coerce_timedelta(value: float | timedelta, *, key: str) -> timedelta:
127 if isinstance(value, timedelta):
128 return value
129 logger.debug("%s isn't a timedelta object, assuming secs", key)
130 return timedelta(seconds=value)
133def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
134 if resources is None:
135 return None
136 return Resources(**resources)
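# Illustrative sketch, not part of the upstream module: how the coercion helpers
# above normalize loosely typed user input. The helper name is hypothetical, and
# the Resources keyword assumes its constructor accepts ``cpus``.
def _example_coercions() -> None:
    # "3" is implicitly converted to the int 3 (a warning is logged).
    assert parse_retries("3") == 3
    # A bare number is interpreted as seconds and wrapped in a timedelta.
    assert coerce_timedelta(300, key="retry_delay") == timedelta(seconds=300)
    # A plain dict becomes a Resources object; None passes through unchanged.
    assert coerce_resources({"cpus": 1}) is not None
    assert coerce_resources(None) is None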
139def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
140 if not dag:
141 return {}, ParamsDict()
142 dag_args = copy.copy(dag.default_args)
143 dag_params = copy.deepcopy(dag.params)
144 if task_group:
145 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
146 raise TypeError("default_args must be a mapping")
147 dag_args.update(task_group.default_args)
148 return dag_args, dag_params
151def get_merged_defaults(
152 dag: DAG | None,
153 task_group: TaskGroup | None,
154 task_params: collections.abc.MutableMapping | None,
155 task_default_args: dict | None,
156) -> tuple[dict, ParamsDict]:
157 args, params = _get_parent_defaults(dag, task_group)
158 if task_params:
159 if not isinstance(task_params, collections.abc.Mapping):
160 raise TypeError("params must be a mapping")
161 params.update(task_params)
162 if task_default_args:
163 if not isinstance(task_default_args, collections.abc.Mapping):
164 raise TypeError("default_args must be a mapping")
165 args.update(task_default_args)
166 with contextlib.suppress(KeyError):
167 params.update(task_default_args["params"] or {})
168 return args, params
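# Illustrative sketch, not part of the upstream module (the helper name is
# hypothetical): with no DAG or TaskGroup, only the task-level values survive the
# merge; with a DAG, its default_args and params would be merged in first and then
# overridden by the task-level ones.
def _example_merged_defaults() -> None:
    args, params = get_merged_defaults(
        dag=None,
        task_group=None,
        task_params={"limit": 10},
        task_default_args={"retries": 2},
    )
    assert args == {"retries": 2}
    assert params["limit"] == 10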
171class _PartialDescriptor:
172 """A descriptor that guards against ``.partial`` being called on Task objects."""
174 class_method: ClassMethodDescriptorType | None = None
176 def __get__(
177 self, obj: BaseOperator, cls: type[BaseOperator] | None = None
178 ) -> Callable[..., OperatorPartial]:
179 # Call this "partial" so it looks nicer in stack traces.
180 def partial(**kwargs):
181 raise TypeError("partial can only be called on Operator classes, not Tasks themselves")
183 if obj is not None:
184 return partial
185 return self.class_method.__get__(cls, cls)
188_PARTIAL_DEFAULTS = {
189 "owner": DEFAULT_OWNER,
190 "trigger_rule": DEFAULT_TRIGGER_RULE,
191 "depends_on_past": False,
192 "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
193 "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
194 "wait_for_downstream": False,
195 "retries": DEFAULT_RETRIES,
196 "queue": DEFAULT_QUEUE,
197 "pool_slots": DEFAULT_POOL_SLOTS,
198 "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
199 "retry_delay": DEFAULT_RETRY_DELAY,
200 "retry_exponential_backoff": False,
201 "priority_weight": DEFAULT_PRIORITY_WEIGHT,
202 "weight_rule": DEFAULT_WEIGHT_RULE,
203 "inlets": [],
204 "outlets": [],
205}
208# This is what handles the actual mapping.
209def partial(
210 operator_class: type[BaseOperator],
211 *,
212 task_id: str,
213 dag: DAG | None = None,
214 task_group: TaskGroup | None = None,
215 start_date: datetime | ArgNotSet = NOTSET,
216 end_date: datetime | ArgNotSet = NOTSET,
217 owner: str | ArgNotSet = NOTSET,
218 email: None | str | Iterable[str] | ArgNotSet = NOTSET,
219 params: collections.abc.MutableMapping | None = None,
220 resources: dict[str, Any] | None | ArgNotSet = NOTSET,
221 trigger_rule: str | ArgNotSet = NOTSET,
222 depends_on_past: bool | ArgNotSet = NOTSET,
223 ignore_first_depends_on_past: bool | ArgNotSet = NOTSET,
224 wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET,
225 wait_for_downstream: bool | ArgNotSet = NOTSET,
226 retries: int | None | ArgNotSet = NOTSET,
227 queue: str | ArgNotSet = NOTSET,
228 pool: str | ArgNotSet = NOTSET,
229 pool_slots: int | ArgNotSet = NOTSET,
230 execution_timeout: timedelta | None | ArgNotSet = NOTSET,
231 max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET,
232 retry_delay: timedelta | float | ArgNotSet = NOTSET,
233 retry_exponential_backoff: bool | ArgNotSet = NOTSET,
234 priority_weight: int | ArgNotSet = NOTSET,
235 weight_rule: str | ArgNotSet = NOTSET,
236 sla: timedelta | None | ArgNotSet = NOTSET,
237 max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
238 max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET,
239 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
240 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
241 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
242 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
243 run_as_user: str | None | ArgNotSet = NOTSET,
244 executor_config: dict | None | ArgNotSet = NOTSET,
245 inlets: Any | None | ArgNotSet = NOTSET,
246 outlets: Any | None | ArgNotSet = NOTSET,
247 doc: str | None | ArgNotSet = NOTSET,
248 doc_md: str | None | ArgNotSet = NOTSET,
249 doc_json: str | None | ArgNotSet = NOTSET,
250 doc_yaml: str | None | ArgNotSet = NOTSET,
251 doc_rst: str | None | ArgNotSet = NOTSET,
252 **kwargs,
253) -> OperatorPartial:
254 from airflow.models.dag import DagContext
255 from airflow.utils.task_group import TaskGroupContext
257 validate_mapping_kwargs(operator_class, "partial", kwargs)
259 dag = dag or DagContext.get_current_dag()
260 if dag:
261 task_group = task_group or TaskGroupContext.get_current_task_group(dag)
262 if task_group:
263 task_id = task_group.child_id(task_id)
265 # Merge DAG and task group level defaults into user-supplied values.
266 dag_default_args, partial_params = get_merged_defaults(
267 dag=dag,
268 task_group=task_group,
269 task_params=params,
270 task_default_args=kwargs.pop("default_args", None),
271 )
273 # Create partial_kwargs from args and kwargs
274 partial_kwargs: dict[str, Any] = {
275 **kwargs,
276 "dag": dag,
277 "task_group": task_group,
278 "task_id": task_id,
279 "start_date": start_date,
280 "end_date": end_date,
281 "owner": owner,
282 "email": email,
283 "trigger_rule": trigger_rule,
284 "depends_on_past": depends_on_past,
285 "ignore_first_depends_on_past": ignore_first_depends_on_past,
286 "wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping,
287 "wait_for_downstream": wait_for_downstream,
288 "retries": retries,
289 "queue": queue,
290 "pool": pool,
291 "pool_slots": pool_slots,
292 "execution_timeout": execution_timeout,
293 "max_retry_delay": max_retry_delay,
294 "retry_delay": retry_delay,
295 "retry_exponential_backoff": retry_exponential_backoff,
296 "priority_weight": priority_weight,
297 "weight_rule": weight_rule,
298 "sla": sla,
299 "max_active_tis_per_dag": max_active_tis_per_dag,
300 "max_active_tis_per_dagrun": max_active_tis_per_dagrun,
301 "on_execute_callback": on_execute_callback,
302 "on_failure_callback": on_failure_callback,
303 "on_retry_callback": on_retry_callback,
304 "on_success_callback": on_success_callback,
305 "run_as_user": run_as_user,
306 "executor_config": executor_config,
307 "inlets": inlets,
308 "outlets": outlets,
309 "resources": resources,
310 "doc": doc,
311 "doc_json": doc_json,
312 "doc_md": doc_md,
313 "doc_rst": doc_rst,
314 "doc_yaml": doc_yaml,
315 }
317 # Inject DAG-level default args into args provided to this function.
318 partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET)
320 # Fill fields not provided by the user with default values.
321 partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, v in partial_kwargs.items()}
323 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
324 if "task_concurrency" in kwargs: # Reject deprecated option.
325 raise TypeError("unexpected argument: task_concurrency")
326 if partial_kwargs["wait_for_downstream"]:
327 partial_kwargs["depends_on_past"] = True
328 partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"])
329 partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"])
330 if partial_kwargs["pool"] is None:
331 partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
332 partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"])
333 partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay")
334 if partial_kwargs["max_retry_delay"] is not None:
335 partial_kwargs["max_retry_delay"] = coerce_timedelta(
336 partial_kwargs["max_retry_delay"],
337 key="max_retry_delay",
338 )
339 partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {}
340 partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"])
342 return OperatorPartial(
343 operator_class=operator_class,
344 kwargs=partial_kwargs,
345 params=partial_params,
346 )
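# Illustrative sketch, not part of the upstream module: ``partial`` is normally
# reached through ``SomeOperator.partial(...)`` (via the ``_PartialDescriptor``
# above), and the resulting OperatorPartial is expanded into mapped tasks.
# BashOperator and the DAG id are used here only as familiar examples.
def _example_partial_and_expand() -> None:
    from airflow.models.dag import DAG
    from airflow.operators.bash import BashOperator

    with DAG(dag_id="example_mapping", start_date=datetime(2023, 1, 1)):
        # Arguments shared by every mapped task go to ``partial``;
        # the per-task argument is supplied to ``expand``.
        BashOperator.partial(task_id="probe", retries=2).expand(
            bash_command=["echo 1", "echo 2", "echo 3"]
        )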
349class BaseOperatorMeta(abc.ABCMeta):
350 """Metaclass of BaseOperator."""
352 @classmethod
353 def _apply_defaults(cls, func: T) -> T:
354 """
355 Function decorator that looks for an argument named "default_args", and
356 fills the unspecified arguments from it.
358 Since python2.* isn't clear about which arguments are missing when
359 calling a function, and because this can be quite confusing with multi-level
360 inheritance and argument defaults, this decorator also alerts with
361 specific information about the missing arguments.
362 """
363 # Cache inspect.signature for the wrapper closure to avoid calling it
364 # at every decorated invocation. This is a separate sig_cache created
365 # per decoration, i.e. each function decorated using apply_defaults will
366 # have a different sig_cache.
367 sig_cache = signature(func)
368 non_variadic_params = {
369 name: param
370 for (name, param) in sig_cache.parameters.items()
371 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
372 }
373 non_optional_args = {
374 name
375 for name, param in non_variadic_params.items()
376 if param.default == param.empty and name != "task_id"
377 }
379 fixup_decorator_warning_stack(func)
381 @functools.wraps(func)
382 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
383 from airflow.models.dag import DagContext
384 from airflow.utils.task_group import TaskGroupContext
386 if len(args) > 0:
387 raise AirflowException("Use keyword arguments when initializing operators")
389 instantiated_from_mapped = kwargs.pop(
390 "_airflow_from_mapped",
391 getattr(self, "_BaseOperator__from_mapped", False),
392 )
394 dag: DAG | None = kwargs.get("dag") or DagContext.get_current_dag()
395 task_group: TaskGroup | None = kwargs.get("task_group")
396 if dag and not task_group:
397 task_group = TaskGroupContext.get_current_task_group(dag)
399 default_args, merged_params = get_merged_defaults(
400 dag=dag,
401 task_group=task_group,
402 task_params=kwargs.pop("params", None),
403 task_default_args=kwargs.pop("default_args", None),
404 )
406 for arg in sig_cache.parameters:
407 if arg not in kwargs and arg in default_args:
408 kwargs[arg] = default_args[arg]
410 missing_args = non_optional_args - set(kwargs)
411 if len(missing_args) == 1:
412 raise AirflowException(f"missing keyword argument {missing_args.pop()!r}")
413 elif missing_args:
414 display = ", ".join(repr(a) for a in sorted(missing_args))
415 raise AirflowException(f"missing keyword arguments {display}")
417 if merged_params:
418 kwargs["params"] = merged_params
420 hook = getattr(self, "_hook_apply_defaults", None)
421 if hook:
422 args, kwargs = hook(**kwargs, default_args=default_args)
423 default_args = kwargs.pop("default_args", {})
425 if not hasattr(self, "_BaseOperator__init_kwargs"):
426 self._BaseOperator__init_kwargs = {}
427 self._BaseOperator__from_mapped = instantiated_from_mapped
429 result = func(self, **kwargs, default_args=default_args)
431 # Store the args passed to init -- we need them to support task.map serialization!
432 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
434 # Set upstream task defined by XComArgs passed to template fields of the operator.
435 # BUT: only do this _ONCE_, not once for each class in the hierarchy
436 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
437 self.set_xcomargs_dependencies()
438 # Mark instance as instantiated.
439 self._BaseOperator__instantiated = True
441 return result
443 apply_defaults.__non_optional_args = non_optional_args # type: ignore
444 apply_defaults.__param_names = set(non_variadic_params) # type: ignore
446 return cast(T, apply_defaults)
448 def __new__(cls, name, bases, namespace, **kwargs):
449 new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
450 with contextlib.suppress(KeyError):
451 # Update the partial descriptor with the class method, so it calls the actual function
452 # (but let subclasses override it if they need to)
453 partial_desc = vars(new_cls)["partial"]
454 if isinstance(partial_desc, _PartialDescriptor):
455 partial_desc.class_method = classmethod(partial)
456 new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
457 return new_cls
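# Illustrative sketch, not part of the upstream module (the helper name is
# hypothetical): because the metaclass wraps every operator's ``__init__`` with
# ``_apply_defaults``, positional arguments are rejected and missing required
# keyword arguments are reported by name. EmptyOperator is just a convenient
# concrete operator for the demonstration.
def _example_apply_defaults_behaviour() -> None:
    from airflow.operators.empty import EmptyOperator

    try:
        EmptyOperator("passed_positionally")  # type: ignore[misc]
    except AirflowException as err:
        assert "keyword arguments" in str(err)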
460@functools.total_ordering
461class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
462 """
463 Abstract base class for all operators. Since operators create objects that
464 become nodes in the dag, BaseOperator contains many recursive methods for
465 dag crawling behavior. To derive from this class, you are expected to override
466 the constructor as well as the 'execute' method.
468 Operators derived from this class should perform or trigger certain tasks
469 synchronously (wait for completion). Examples of operators include an
470 operator that runs a Pig job (PigOperator), a sensor operator that
471 waits for a partition to land in Hive (HiveSensorOperator), or one that
472 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these
473 operators (tasks) target specific operations, running specific scripts,
474 functions or data transfers.
476 This class is abstract and shouldn't be instantiated. Instantiating a
477 class derived from this one results in the creation of a task object,
478 which ultimately becomes a node in DAG objects. Task dependencies should
479 be set by using the set_upstream and/or set_downstream methods.
481 :param task_id: a unique, meaningful id for the task
482 :param owner: the owner of the task. Using a meaningful description
483 (e.g. user/person/team/role name) to clarify ownership is recommended.
484 :param email: the 'to' email address(es) used in email alerts. This can be a
485 single email or multiple ones. Multiple addresses can be specified as a
486 comma or semicolon separated string or by passing a list of strings.
487 :param email_on_retry: Indicates whether email alerts should be sent when a
488 task is retried
489 :param email_on_failure: Indicates whether email alerts should be sent when
490 a task fails
491 :param retries: the number of retries that should be performed before
492 failing the task
493 :param retry_delay: delay between retries, can be set as ``timedelta`` or
494 ``float`` seconds, which will be converted into ``timedelta``,
495 the default is ``timedelta(seconds=300)``.
496 :param retry_exponential_backoff: allow progressively longer waits between
497 retries by using exponential backoff algorithm on retry delay (delay
498 will be converted into seconds)
499 :param max_retry_delay: maximum delay interval between retries, can be set as
500 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``.
501 :param start_date: The ``start_date`` for the task, determines
502 the ``execution_date`` for the first task instance. The best practice
503 is to have the start_date rounded
504 to your DAG's ``schedule_interval``. Daily jobs have their start_date
505 some day at 00:00:00, hourly jobs have their start_date at 00:00
506 of a specific hour. Note that Airflow simply looks at the latest
507 ``execution_date`` and adds the ``schedule_interval`` to determine
508 the next ``execution_date``. It is also very important
509 to note that different tasks' dependencies
510 need to line up in time. If task A depends on task B and their
511 start_date are offset in a way that their execution_date don't line
512 up, A's dependencies will never be met. If you are looking to delay
513 a task, for example running a daily task at 2AM, look into the
514 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
515 dynamic ``start_date`` and recommend using fixed ones. Read the
516 FAQ entry about start_date for more information.
517 :param end_date: if specified, the scheduler won't go beyond this date
518 :param depends_on_past: when set to true, task instances will run
519 sequentially and only if the previous instance has succeeded or has been skipped.
520 The task instance for the start_date is allowed to run.
521 :param wait_for_past_depends_before_skipping: when set to true, if the task instance
522 should be marked as skipped, and depends_on_past is true, the ti will stay in None state
523 waiting for the task of the previous run
524 :param wait_for_downstream: when set to true, an instance of task
525 X will wait for tasks immediately downstream of the previous instance
526 of task X to finish successfully or be skipped before it runs. This is useful if the
527 different instances of a task X alter the same asset, and this asset
528 is used by tasks downstream of task X. Note that depends_on_past
529 is forced to True wherever wait_for_downstream is used. Also note that
530 only tasks *immediately* downstream of the previous task instance are waited
531 for; the statuses of any tasks further downstream are ignored.
532 :param dag: a reference to the dag the task is attached to (if any)
533 :param priority_weight: priority weight of this task against other tasks.
534 This allows the executor to trigger higher priority tasks before
535 others when things get backed up. Set priority_weight as a higher
536 number for more important tasks.
537 :param weight_rule: weighting method used for the effective total
538 priority weight of the task. Options are:
539 ``{ downstream | upstream | absolute }``; default is ``downstream``.
540 When set to ``downstream`` the effective weight of the task is the
541 aggregate sum of all downstream descendants. As a result, upstream
542 tasks will have higher weight and will be scheduled more aggressively
543 when using positive weight values. This is useful when you have
544 multiple dag run instances and desire to have all upstream tasks
545 complete for all runs before each dag can continue processing
546 downstream tasks. When set to ``upstream`` the effective weight is the
547 aggregate sum of all upstream ancestors. This is the opposite where
548 downstream tasks have higher weight and will be scheduled more
549 aggressively when using positive weight values. This is useful when you
550 have multiple dag run instances and prefer to have each dag complete
551 before starting upstream tasks of other dags. When set to
552 ``absolute``, the effective weight is the exact ``priority_weight``
553 specified without additional weighting. You may want to do this when
554 you know exactly what priority weight each task should have.
555 Additionally, when set to ``absolute``, there is a bonus effect of
556 significantly speeding up the task creation process for very large
557 DAGs. Options can be set as string or using the constants defined in
558 the static class ``airflow.utils.WeightRule``.
559 :param queue: which queue to target when running this job. Not
560 all executors implement queue management; the CeleryExecutor
561 does support targeting specific queues.
562 :param pool: the slot pool this task should run in, slot pools are a
563 way to limit concurrency for certain tasks
564 :param pool_slots: the number of pool slots this task should use (>= 1).
565 Values less than 1 are not allowed.
566 :param sla: time by which the job is expected to succeed. Note that
567 this represents the ``timedelta`` after the period is closed. For
568 example if you set an SLA of 1 hour, the scheduler would send an email
569 soon after 1:00AM on the ``2016-01-02`` if the ``2016-01-01`` instance
570 has not succeeded yet.
571 The scheduler pays special attention to jobs with an SLA and
572 sends alert
573 emails for SLA misses. SLA misses are also recorded in the database
574 for future reference. All tasks that share the same SLA time
575 get bundled in a single email, sent soon after that time. SLA
576 notifications are sent once and only once for each task instance.
577 :param execution_timeout: max time allowed for the execution of
578 this task instance; if it goes beyond this limit, the task will raise and fail.
579 :param on_failure_callback: a function or list of functions to be called when a task instance
580 of this task fails. A context dictionary is passed as a single
581 parameter to this function. Context contains references to objects
582 related to the task instance and is documented under the macros
583 section of the API.
584 :param on_execute_callback: much like the ``on_failure_callback`` except
585 that it is executed right before the task is executed.
586 :param on_retry_callback: much like the ``on_failure_callback`` except
587 that it is executed when retries occur.
588 :param on_success_callback: much like the ``on_failure_callback`` except
589 that it is executed when the task succeeds.
590 :param pre_execute: a function to be called immediately before task
591 execution, receiving a context dictionary; raising an exception will
592 prevent the task from being executed.
594 |experimental|
595 :param post_execute: a function to be called immediately after task
596 execution, receiving a context dictionary and task result; raising an
597 exception will prevent the task from succeeding.
599 |experimental|
600 :param trigger_rule: defines the rule by which dependencies are applied
601 for the task to get triggered. Options are:
602 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
603 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
604 default is ``all_success``. Options can be set as string or
605 using the constants defined in the static class
606 ``airflow.utils.TriggerRule``
607 :param resources: A map of resource parameter names (the argument names of the
608 Resources constructor) to their values.
609 :param run_as_user: unix username to impersonate while running the task
610 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent
611 runs across execution_dates.
612 :param max_active_tis_per_dagrun: When set, a task will be able to limit the concurrent
613 task instances per DAG run.
614 :param executor_config: Additional task-level configuration parameters that are
615 interpreted by a specific executor. Parameters are namespaced by the name of
616 the executor.
618 **Example**: to run this task in a specific docker container through
619 the KubernetesExecutor ::
621 MyOperator(...,
622 executor_config={
623 "KubernetesExecutor":
624 {"image": "myCustomDockerImage"}
625 }
626 )
628 :param do_xcom_push: if True, an XCom is pushed containing the Operator's
629 result
630 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not
631 using a TaskGroup as a context manager.
632 :param doc: Add documentation or notes to your Task objects that is visible in
633 Task Instance details View in the Webserver
634 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects
635 that is visible in Task Instance details View in the Webserver
636 :param doc_rst: Add documentation (in RST format) or notes to your Task objects
637 that is visible in Task Instance details View in the Webserver
638 :param doc_json: Add documentation (in JSON format) or notes to your Task objects
639 that is visible in Task Instance details View in the Webserver
640 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects
641 that is visible in Task Instance details View in the Webserver
642 """
644 # Implementing Operator.
645 template_fields: Sequence[str] = ()
646 template_ext: Sequence[str] = ()
648 template_fields_renderers: dict[str, str] = {}
650 # Defines the color in the UI
651 ui_color: str = "#fff"
652 ui_fgcolor: str = "#000"
654 pool: str = ""
656 # base list which includes all the attrs that don't need deep copy.
657 _base_operator_shallow_copy_attrs: tuple[str, ...] = (
658 "user_defined_macros",
659 "user_defined_filters",
660 "params",
661 "_log",
662 )
664 # each operator should override this class attr for shallow copy attrs.
665 shallow_copy_attrs: Sequence[str] = ()
667 # Defines the operator level extra links
668 operator_extra_links: Collection[BaseOperatorLink] = ()
670 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
671 __serialized_fields: frozenset[str] | None = None
673 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore
675 _comps = {
676 "task_id",
677 "dag_id",
678 "owner",
679 "email",
680 "email_on_retry",
681 "retry_delay",
682 "retry_exponential_backoff",
683 "max_retry_delay",
684 "start_date",
685 "end_date",
686 "depends_on_past",
687 "wait_for_downstream",
688 "priority_weight",
689 "sla",
690 "execution_timeout",
691 "on_execute_callback",
692 "on_failure_callback",
693 "on_success_callback",
694 "on_retry_callback",
695 "do_xcom_push",
696 }
698 # Defines if the operator supports lineage without manual definitions
699 supports_lineage = False
701 # If True then the class constructor was called
702 __instantiated = False
703 # List of args as passed to `__init__()`, after apply_defaults() has updated them. Used to "recreate" the task
704 # when mapping
705 __init_kwargs: dict[str, Any]
707 # Set to True before calling execute method
708 _lock_for_execution = False
710 _dag: DAG | None = None
711 task_group: TaskGroup | None = None
713 # subdag parameter is only set for SubDagOperator.
714 # Setting it to None by default as other Operators do not have that field
715 subdag: DAG | None = None
717 start_date: pendulum.DateTime | None = None
718 end_date: pendulum.DateTime | None = None
720 # Set to True for an operator instantiated by a mapped operator.
721 __from_mapped = False
723 is_setup = False
724 """
725 Whether the operator is a setup task
727 :meta private:
728 """
729 is_teardown = False
730 """
731 Whether the operator is a teardown task
733 :meta private:
734 """
735 on_failure_fail_dagrun = False
736 """
737 Whether the operator should fail the dagrun on failure
739 :meta private:
740 """
742 def __init__(
743 self,
744 task_id: str,
745 owner: str = DEFAULT_OWNER,
746 email: str | Iterable[str] | None = None,
747 email_on_retry: bool = conf.getboolean("email", "default_email_on_retry", fallback=True),
748 email_on_failure: bool = conf.getboolean("email", "default_email_on_failure", fallback=True),
749 retries: int | None = DEFAULT_RETRIES,
750 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
751 retry_exponential_backoff: bool = False,
752 max_retry_delay: timedelta | float | None = None,
753 start_date: datetime | None = None,
754 end_date: datetime | None = None,
755 depends_on_past: bool = False,
756 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
757 wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
758 wait_for_downstream: bool = False,
759 dag: DAG | None = None,
760 params: collections.abc.MutableMapping | None = None,
761 default_args: dict | None = None,
762 priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
763 weight_rule: str = DEFAULT_WEIGHT_RULE,
764 queue: str = DEFAULT_QUEUE,
765 pool: str | None = None,
766 pool_slots: int = DEFAULT_POOL_SLOTS,
767 sla: timedelta | None = None,
768 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
769 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
770 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
771 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
772 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
773 pre_execute: TaskPreExecuteHook | None = None,
774 post_execute: TaskPostExecuteHook | None = None,
775 trigger_rule: str = DEFAULT_TRIGGER_RULE,
776 resources: dict[str, Any] | None = None,
777 run_as_user: str | None = None,
778 task_concurrency: int | None = None,
779 max_active_tis_per_dag: int | None = None,
780 max_active_tis_per_dagrun: int | None = None,
781 executor_config: dict | None = None,
782 do_xcom_push: bool = True,
783 inlets: Any | None = None,
784 outlets: Any | None = None,
785 task_group: TaskGroup | None = None,
786 doc: str | None = None,
787 doc_md: str | None = None,
788 doc_json: str | None = None,
789 doc_yaml: str | None = None,
790 doc_rst: str | None = None,
791 **kwargs,
792 ):
793 from airflow.models.dag import DagContext
794 from airflow.utils.task_group import TaskGroupContext
796 self.__init_kwargs = {}
798 super().__init__()
800 kwargs.pop("_airflow_mapped_validation_only", None)
801 if kwargs:
802 if not conf.getboolean("operators", "ALLOW_ILLEGAL_ARGUMENTS"):
803 raise AirflowException(
804 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
805 f"Invalid arguments were:\n**kwargs: {kwargs}",
806 )
807 warnings.warn(
808 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
809 "Support for passing such arguments will be dropped in future. "
810 f"Invalid arguments were:\n**kwargs: {kwargs}",
811 category=RemovedInAirflow3Warning,
812 stacklevel=3,
813 )
814 validate_key(task_id)
816 dag = dag or DagContext.get_current_dag()
817 task_group = task_group or TaskGroupContext.get_current_task_group(dag)
819 DagInvalidTriggerRule.check(dag, trigger_rule)
821 self.task_id = task_group.child_id(task_id) if task_group else task_id
822 if not self.__from_mapped and task_group:
823 task_group.add(self)
825 self.owner = owner
826 self.email = email
827 self.email_on_retry = email_on_retry
828 self.email_on_failure = email_on_failure
830 if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
831 raise ValueError(
832 f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
833 )
834 self.execution_timeout = execution_timeout
836 self.on_execute_callback = on_execute_callback
837 self.on_failure_callback = on_failure_callback
838 self.on_success_callback = on_success_callback
839 self.on_retry_callback = on_retry_callback
840 self._pre_execute_hook = pre_execute
841 self._post_execute_hook = post_execute
843 if start_date and not isinstance(start_date, datetime):
844 self.log.warning("start_date for %s isn't datetime.datetime", self)
845 elif start_date:
846 self.start_date = timezone.convert_to_utc(start_date)
848 if end_date:
849 self.end_date = timezone.convert_to_utc(end_date)
851 self.executor_config = executor_config or {}
852 self.run_as_user = run_as_user
853 self.retries = parse_retries(retries)
854 self.queue = queue
855 self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
856 self.pool_slots = pool_slots
857 if self.pool_slots < 1:
858 dag_str = f" in dag {dag.dag_id}" if dag else ""
859 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
860 self.sla = sla
862 if trigger_rule == "dummy":
863 warnings.warn(
864 "dummy Trigger Rule is deprecated. Please use `TriggerRule.ALWAYS`.",
865 RemovedInAirflow3Warning,
866 stacklevel=2,
867 )
868 trigger_rule = TriggerRule.ALWAYS
870 if trigger_rule == "none_failed_or_skipped":
871 warnings.warn(
872 "none_failed_or_skipped Trigger Rule is deprecated. "
873 "Please use `none_failed_min_one_success`.",
874 RemovedInAirflow3Warning,
875 stacklevel=2,
876 )
877 trigger_rule = TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
879 if not TriggerRule.is_valid(trigger_rule):
880 raise AirflowException(
881 f"The trigger_rule must be one of {TriggerRule.all_triggers()},"
882 f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
883 )
885 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
886 self.depends_on_past: bool = depends_on_past
887 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
888 self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
889 self.wait_for_downstream: bool = wait_for_downstream
890 if wait_for_downstream:
891 self.depends_on_past = True
893 self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay")
894 self.retry_exponential_backoff = retry_exponential_backoff
895 self.max_retry_delay = (
896 max_retry_delay
897 if max_retry_delay is None
898 else coerce_timedelta(max_retry_delay, key="max_retry_delay")
899 )
901 # At execution time this becomes a normal dict
902 self.params: ParamsDict | dict = ParamsDict(params)
903 if priority_weight is not None and not isinstance(priority_weight, int):
904 raise AirflowException(
905 f"`priority_weight` for task '{self.task_id}' only accepts integers, "
906 f"received '{type(priority_weight)}'."
907 )
908 self.priority_weight = priority_weight
909 if not WeightRule.is_valid(weight_rule):
910 raise AirflowException(
911 f"The weight_rule must be one of "
912 f"{WeightRule.all_weight_rules},'{dag.dag_id if dag else ''}.{task_id}'; "
913 f"received '{weight_rule}'."
914 )
915 self.weight_rule = weight_rule
916 self.resources = coerce_resources(resources)
917 if task_concurrency and not max_active_tis_per_dag:
918 # TODO: Remove in Airflow 3.0
919 warnings.warn(
920 "The 'task_concurrency' parameter is deprecated. Please use 'max_active_tis_per_dag'.",
921 RemovedInAirflow3Warning,
922 stacklevel=2,
923 )
924 max_active_tis_per_dag = task_concurrency
925 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
926 self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
927 self.do_xcom_push = do_xcom_push
929 self.doc_md = doc_md
930 self.doc_json = doc_json
931 self.doc_yaml = doc_yaml
932 self.doc_rst = doc_rst
933 self.doc = doc
935 self.upstream_task_ids: set[str] = set()
936 self.downstream_task_ids: set[str] = set()
938 if dag:
939 self.dag = dag
941 self._log = logging.getLogger("airflow.task.operators")
943 # Lineage
944 self.inlets: list = []
945 self.outlets: list = []
947 if inlets:
948 self.inlets = (
949 inlets
950 if isinstance(inlets, list)
951 else [
952 inlets,
953 ]
954 )
956 if outlets:
957 self.outlets = (
958 outlets
959 if isinstance(outlets, list)
960 else [
961 outlets,
962 ]
963 )
965 if isinstance(self.template_fields, str):
966 warnings.warn(
967 f"The `template_fields` value for {self.task_type} is a string "
968 "but should be a list or tuple of string. Wrapping it in a list for execution. "
969 f"Please update {self.task_type} accordingly.",
970 UserWarning,
971 stacklevel=2,
972 )
973 self.template_fields = [self.template_fields]
975 if SetupTeardownContext.active:
976 SetupTeardownContext.update_context_map(self)
978 @classmethod
979 def as_setup(cls, *args, **kwargs):
980 op = cls(*args, **kwargs)
981 op.is_setup = True
982 return op
984 @classmethod
985 def as_teardown(cls, *args, **kwargs):
986 on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)
987 if "trigger_rule" in kwargs:
988 raise ValueError("Cannot set trigger rule for teardown tasks.")
989 op = cls(*args, **kwargs, trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
990 op.is_teardown = True
991 op.on_failure_fail_dagrun = on_failure_fail_dagrun
992 return op
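# Illustrative usage (comment only, not part of the upstream class): the
# ``as_setup`` / ``as_teardown`` constructors above mark a task for setup/teardown
# handling; the operators and task ids are hypothetical.
#
#   create_cluster = BashOperator.as_setup(task_id="create", bash_command="...")
#   delete_cluster = BashOperator.as_teardown(
#       task_id="delete", bash_command="...", on_failure_fail_dagrun=True
#   )
#   create_cluster >> work >> delete_cluster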
994 def __enter__(self):
995 if not self.is_setup and not self.is_teardown:
996 raise AirflowException("Only setup/teardown tasks can be used as context managers.")
997 SetupTeardownContext.push_setup_teardown_task(self)
998 return self
1000 def __exit__(self, exc_type, exc_val, exc_tb):
1001 SetupTeardownContext.set_work_task_roots_and_leaves()
1003 def __eq__(self, other):
1004 if type(self) is type(other):
1005 # Use getattr() instead of __dict__ as __dict__ doesn't return
1006 # correct values for properties.
1007 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
1008 return False
1010 def __ne__(self, other):
1011 return not self == other
1013 def __hash__(self):
1014 hash_components = [type(self)]
1015 for component in self._comps:
1016 val = getattr(self, component, None)
1017 try:
1018 hash(val)
1019 hash_components.append(val)
1020 except TypeError:
1021 hash_components.append(repr(val))
1022 return hash(tuple(hash_components))
1024 # including lineage information
1025 def __or__(self, other):
1026 """
1027 Called for [This Operator] | [Operator]. The inlets of other
1028 will be set to pick up the outlets from this operator. Other will
1029 be set as a downstream task of this operator.
1030 """
1031 if isinstance(other, BaseOperator):
1032 if not self.outlets and not self.supports_lineage:
1033 raise ValueError("No outlets defined for this operator")
1034 other.add_inlets([self.task_id])
1035 self.set_downstream(other)
1036 else:
1037 raise TypeError(f"Right hand side ({other}) is not an Operator")
1039 return self
1041 # /Composing Operators ---------------------------------------------
1043 def __gt__(self, other):
1044 """
1045 Called for [Operator] > [Outlet], so that if other is an attr annotated object
1046 it is set as an outlet of this Operator.
1047 """
1048 if not isinstance(other, Iterable):
1049 other = [other]
1051 for obj in other:
1052 if not attr.has(obj):
1053 raise TypeError(f"Right hand side ({obj}) is not an outlet")
1054 self.add_outlets(other)
1056 return self
1058 def __lt__(self, other):
1059 """
1060 Called for [Inlet] > [Operator] or [Operator] < [Inlet], so that if other is
1061 an attr annotated object it is set as an inlet to this operator.
1062 """
1063 if not isinstance(other, Iterable):
1064 other = [other]
1066 for obj in other:
1067 if not attr.has(obj):
1068 raise TypeError(f"{obj} cannot be an inlet")
1069 self.add_inlets(other)
1071 return self
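# Illustrative usage (comment only, not part of the upstream class): ``>`` / ``<``
# attach attr-annotated lineage entities, while ``|`` wires lineage plus a
# downstream dependency. ``File`` from airflow.lineage.entities is assumed to be
# such an attr-annotated entity; ``extract`` and ``load`` are hypothetical tasks.
#
#   from airflow.lineage.entities import File
#
#   extract > File(url="s3://bucket/raw.csv")   # add an outlet to ``extract``
#   File(url="s3://bucket/raw.csv") > load      # add an inlet to ``load``
#   extract | load                              # ``load`` inherits outlets, runs downstream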
1073 def __setattr__(self, key, value):
1074 super().__setattr__(key, value)
1075 if self.__from_mapped or self._lock_for_execution:
1076 return # Skip any custom behavior for validation and during execute.
1077 if key in self.__init_kwargs:
1078 self.__init_kwargs[key] = value
1079 if self.__instantiated and key in self.template_fields:
1080 # Resolve upstreams set by assigning an XComArg after initializing
1081 # an operator, example:
1082 # op = BashOperator()
1083 # op.bash_command = "sleep 1"
1084 self.set_xcomargs_dependencies()
1086 def add_inlets(self, inlets: Iterable[Any]):
1087 """Sets inlets to this operator."""
1088 self.inlets.extend(inlets)
1090 def add_outlets(self, outlets: Iterable[Any]):
1091 """Defines the outlets of this operator."""
1092 self.outlets.extend(outlets)
1094 def get_inlet_defs(self):
1095 """Gets inlet definitions on this task.
1097 :meta private:
1098 """
1099 return self.inlets
1101 def get_outlet_defs(self):
1102 """Gets outlet definitions on this task.
1104 :meta private:
1105 """
1106 return self.outlets
1108 def get_dag(self) -> DAG | None:
1109 return self._dag
1111 @property # type: ignore[override]
1112 def dag(self) -> DAG: # type: ignore[override]
1113 """Returns the Operator's DAG if set, otherwise raises an error."""
1114 if self._dag:
1115 return self._dag
1116 else:
1117 raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
1119 @dag.setter
1120 def dag(self, dag: DAG | None):
1121 """
1122 Operators can be assigned to one DAG, one time. Repeat assignments to
1123 that same DAG are ok.
1124 """
1125 from airflow.models.dag import DAG
1127 if dag is None:
1128 self._dag = None
1129 return
1130 if not isinstance(dag, DAG):
1131 raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
1132 elif self.has_dag() and self.dag is not dag:
1133 raise AirflowException(f"The DAG assigned to {self} can not be changed.")
1135 if self.__from_mapped:
1136 pass # Don't add to DAG -- the mapped task takes the place.
1137 elif self.task_id not in dag.task_dict:
1138 dag.add_task(self)
1139 elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
1140 dag.add_task(self)
1142 self._dag = dag
1144 def has_dag(self):
1145 """Returns True if the Operator has been assigned to a DAG."""
1146 return self._dag is not None
1148 deps: frozenset[BaseTIDep] = frozenset(
1149 {
1150 NotInRetryPeriodDep(),
1151 PrevDagrunDep(),
1152 TriggerRuleDep(),
1153 NotPreviouslySkippedDep(),
1154 }
1155 )
1156 """
1157 Returns the set of dependencies for the operator. These differ from execution
1158 context dependencies in that they are specific to tasks and can be
1159 extended/overridden by subclasses.
1160 """
1162 def prepare_for_execution(self) -> BaseOperator:
1163 """
1164 Lock task for execution to disable custom action in __setattr__ and
1165 return a copy of the task.
1166 """
1167 other = copy.copy(self)
1168 other._lock_for_execution = True
1169 return other
1171 def set_xcomargs_dependencies(self) -> None:
1172 """
1173 Resolves upstream dependencies of a task. In this way passing an ``XComArg``
1173 as value for a template field will result in creating an upstream relation
1174 between the two tasks.
1177 **Example**: ::
1179 with DAG(...):
1180 generate_content = GenerateContentOperator(task_id="generate_content")
1181 send_email = EmailOperator(..., html_content=generate_content.output)
1183 # This is equivalent to
1184 with DAG(...):
1185 generate_content = GenerateContentOperator(task_id="generate_content")
1186 send_email = EmailOperator(
1187 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
1188 )
1189 generate_content >> send_email
1191 """
1192 from airflow.models.xcom_arg import XComArg
1194 for field in self.template_fields:
1195 if hasattr(self, field):
1196 arg = getattr(self, field)
1197 XComArg.apply_upstream_relationship(self, arg)
1199 @prepare_lineage
1200 def pre_execute(self, context: Any):
1201 """This hook is triggered right before self.execute() is called."""
1202 if self._pre_execute_hook is not None:
1203 self._pre_execute_hook(context)
1205 def execute(self, context: Context) -> Any:
1206 """
1207 This is the main method to override when creating an operator.
1208 Context is the same dictionary used when rendering jinja templates.
1210 Refer to get_template_context for more context.
1211 """
1212 raise NotImplementedError()
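# Illustrative sketch (comment only, not part of the upstream class): the minimal
# contract for a custom operator is to override ``execute`` and optionally declare
# template fields; the class below is hypothetical.
#
#   class HelloOperator(BaseOperator):
#       template_fields: Sequence[str] = ("name",)
#
#       def __init__(self, *, name: str, **kwargs) -> None:
#           super().__init__(**kwargs)
#           self.name = name
#
#       def execute(self, context: Context) -> str:
#           self.log.info("Hello %s", self.name)
#           return self.name  # pushed to XCom when do_xcom_push is True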
1214 @apply_lineage
1215 def post_execute(self, context: Any, result: Any = None):
1216 """
1217 This hook is triggered right after self.execute() is called.
1218 It is passed the execution context and any results returned by the
1219 operator.
1220 """
1221 if self._post_execute_hook is not None:
1222 self._post_execute_hook(context, result)
1224 def on_kill(self) -> None:
1225 """
1226 Override this method to clean up subprocesses when a task instance
1227 gets killed. Any use of the threading, subprocess or multiprocessing
1228 module within an operator needs to be cleaned up, or it will leave
1229 ghost processes behind.
1230 """
1232 def __deepcopy__(self, memo):
1233 # Hack sorting double chained task lists by task_id to avoid hitting
1234 # max_depth on deepcopy operations.
1235 sys.setrecursionlimit(5000) # TODO fix this in a better way
1237 cls = self.__class__
1238 result = cls.__new__(cls)
1239 memo[id(self)] = result
1241 shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
1243 for k, v in self.__dict__.items():
1244 if k == "_BaseOperator__instantiated":
1245 # Don't set this until the _end_, as it changes behaviour of __setattr__
1246 continue
1247 if k not in shallow_copy:
1248 setattr(result, k, copy.deepcopy(v, memo))
1249 else:
1250 setattr(result, k, copy.copy(v))
1251 result.__instantiated = self.__instantiated
1252 return result
1254 def __getstate__(self):
1255 state = dict(self.__dict__)
1256 del state["_log"]
1258 return state
1260 def __setstate__(self, state):
1261 self.__dict__ = state
1262 self._log = logging.getLogger("airflow.task.operators")
1264 def render_template_fields(
1265 self,
1266 context: Context,
1267 jinja_env: jinja2.Environment | None = None,
1268 ) -> None:
1269 """Template all attributes listed in *self.template_fields*.
1271 This mutates the attributes in-place and is irreversible.
1273 :param context: Context dict with values to apply on content.
1274 :param jinja_env: Jinja's environment to use for rendering.
1275 """
1276 if not jinja_env:
1277 jinja_env = self.get_template_env()
1278 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
1280 @provide_session
1281 def clear(
1282 self,
1283 start_date: datetime | None = None,
1284 end_date: datetime | None = None,
1285 upstream: bool = False,
1286 downstream: bool = False,
1287 session: Session = NEW_SESSION,
1288 ):
1289 """
1290 Clears the state of task instances associated with the task, following
1291 the parameters specified.
1292 """
1293 qry = session.query(TaskInstance).filter(TaskInstance.dag_id == self.dag_id)
1295 if start_date:
1296 qry = qry.filter(TaskInstance.execution_date >= start_date)
1297 if end_date:
1298 qry = qry.filter(TaskInstance.execution_date <= end_date)
1300 tasks = [self.task_id]
1302 if upstream:
1303 tasks += [t.task_id for t in self.get_flat_relatives(upstream=True)]
1305 if downstream:
1306 tasks += [t.task_id for t in self.get_flat_relatives(upstream=False)]
1308 qry = qry.filter(TaskInstance.task_id.in_(tasks))
1309 results = qry.all()
1310 count = len(results)
1311 clear_task_instances(results, session, dag=self.dag)
1312 session.commit()
1313 return count
1315 @provide_session
1316 def get_task_instances(
1317 self,
1318 start_date: datetime | None = None,
1319 end_date: datetime | None = None,
1320 session: Session = NEW_SESSION,
1321 ) -> list[TaskInstance]:
1322 """Get task instances related to this task for a specific date range."""
1323 from airflow.models import DagRun
1325 end_date = end_date or timezone.utcnow()
1326 return (
1327 session.query(TaskInstance)
1328 .join(TaskInstance.dag_run)
1329 .filter(TaskInstance.dag_id == self.dag_id)
1330 .filter(TaskInstance.task_id == self.task_id)
1331 .filter(DagRun.execution_date >= start_date)
1332 .filter(DagRun.execution_date <= end_date)
1333 .order_by(DagRun.execution_date)
1334 .all()
1335 )
1337 @provide_session
1338 def run(
1339 self,
1340 start_date: datetime | None = None,
1341 end_date: datetime | None = None,
1342 ignore_first_depends_on_past: bool = True,
1343 wait_for_past_depends_before_skipping: bool = False,
1344 ignore_ti_state: bool = False,
1345 mark_success: bool = False,
1346 test_mode: bool = False,
1347 session: Session = NEW_SESSION,
1348 ) -> None:
1349 """Run a set of task instances for a date range."""
1350 from airflow.models import DagRun
1351 from airflow.utils.types import DagRunType
1353 # Assertions for typing -- we need a dag, for this function, and when we have a DAG we are
1354 # _guaranteed_ to have start_date (else we couldn't have been added to a DAG)
1355 if TYPE_CHECKING:
1356 assert self.start_date
1358 start_date = pendulum.instance(start_date or self.start_date)
1359 end_date = pendulum.instance(end_date or self.end_date or timezone.utcnow())
1361 for info in self.dag.iter_dagrun_infos_between(start_date, end_date, align=False):
1362 ignore_depends_on_past = info.logical_date == start_date and ignore_first_depends_on_past
1363 try:
1364 dag_run = (
1365 session.query(DagRun)
1366 .filter(
1367 DagRun.dag_id == self.dag_id,
1368 DagRun.execution_date == info.logical_date,
1369 )
1370 .one()
1371 )
1372 ti = TaskInstance(self, run_id=dag_run.run_id)
1373 except NoResultFound:
1374 # This is _mostly_ only used in tests
1375 dr = DagRun(
1376 dag_id=self.dag_id,
1377 run_id=DagRun.generate_run_id(DagRunType.MANUAL, info.logical_date),
1378 run_type=DagRunType.MANUAL,
1379 execution_date=info.logical_date,
1380 data_interval=info.data_interval,
1381 )
1382 ti = TaskInstance(self, run_id=dr.run_id)
1383 ti.dag_run = dr
1384 session.add(dr)
1385 session.flush()
1387 ti.run(
1388 mark_success=mark_success,
1389 ignore_depends_on_past=ignore_depends_on_past,
1390 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
1391 ignore_ti_state=ignore_ti_state,
1392 test_mode=test_mode,
1393 session=session,
1394 )
1396 def dry_run(self) -> None:
1397 """Performs dry run for the operator - just render template fields."""
1398 self.log.info("Dry run")
1399 for field in self.template_fields:
1400 try:
1401 content = getattr(self, field)
1402 except AttributeError:
1403 raise AttributeError(
1404 f"{field!r} is configured as a template field "
1405 f"but {self.task_type} does not have this attribute."
1406 )
1408 if content and isinstance(content, str):
1409 self.log.info("Rendering template for %s", field)
1410 self.log.info(content)
1412 def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
1413 """
1414 Get list of the direct relatives to the current task, upstream or
1415 downstream.
1416 """
1417 if upstream:
1418 return self.upstream_list
1419 else:
1420 return self.downstream_list
1422 def __repr__(self):
1423 return "<Task({self.task_type}): {self.task_id}>".format(self=self)
1425 @property
1426 def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
1427 return self.__class__
1429 @property
1430 def task_type(self) -> str:
1431 """@property: type of the task."""
1432 return self.__class__.__name__
1434 @property
1435 def operator_name(self) -> str:
1436 """@property: use a more friendly display name for the operator, if set."""
1437 try:
1438 return self.custom_operator_name # type: ignore
1439 except AttributeError:
1440 return self.task_type
1442 @property
1443 def roots(self) -> list[BaseOperator]:
1444 """Required by DAGNode."""
1445 return [self]
1447 @property
1448 def leaves(self) -> list[BaseOperator]:
1449 """Required by DAGNode."""
1450 return [self]
1452 @property
1453 def output(self) -> XComArg:
1454 """Returns reference to XCom pushed by current operator."""
1455 from airflow.models.xcom_arg import XComArg
1457 return XComArg(operator=self)
1459 @staticmethod
1460 def xcom_push(
1461 context: Any,
1462 key: str,
1463 value: Any,
1464 execution_date: datetime | None = None,
1465 ) -> None:
1466 """
1467 Make an XCom available for tasks to pull.
1469 :param context: Execution Context Dictionary
1470 :param key: A key for the XCom
1471 :param value: A value for the XCom. The value is pickled and stored
1472 in the database.
1473 :param execution_date: if provided, the XCom will not be visible until
1474 this date. This can be used, for example, to send a message to a
1475 task on a future date without it being immediately visible.
1476 """
1477 context["ti"].xcom_push(key=key, value=value, execution_date=execution_date)
1479 @staticmethod
1480 def xcom_pull(
1481 context: Any,
1482 task_ids: str | list[str] | None = None,
1483 dag_id: str | None = None,
1484 key: str = XCOM_RETURN_KEY,
1485 include_prior_dates: bool | None = None,
1486 ) -> Any:
1487 """
1488 Pull XComs that optionally meet certain criteria.
1490 The default value for `key` limits the search to XComs
1491 that were returned by other tasks (as opposed to those that were pushed
1492 manually). To remove this filter, pass key=None (or any desired value).
1494 If a single task_id string is provided, the result is the value of the
1495 most recent matching XCom from that task_id. If multiple task_ids are
1496 provided, a tuple of matching values is returned. None is returned
1497 whenever no matches are found.
1499 :param context: Execution Context Dictionary
1500 :param key: A key for the XCom. If provided, only XComs with matching
1501 keys will be returned. The default key is 'return_value', also
1502 available as a constant XCOM_RETURN_KEY. This key is automatically
1503 given to XComs returned by tasks (as opposed to being pushed
1504 manually). To remove the filter, pass key=None.
1505 :param task_ids: Only XComs from tasks with matching ids will be
1506 pulled. Can pass None to remove the filter.
1507 :param dag_id: If provided, only pulls XComs from this DAG.
1508 If None (default), the DAG of the calling task is used.
1509 :param include_prior_dates: If False, only XComs from the current
1510 execution_date are returned. If True, XComs from previous dates
1511 are returned as well.
1512 """
1513 return context["ti"].xcom_pull(
1514 key=key, task_ids=task_ids, dag_id=dag_id, include_prior_dates=include_prior_dates
1515 )
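# Editor's note: a minimal sketch, not part of the original source, of pulling a value
# pushed by an upstream task; the operator, task id, and key below are hypothetical:
#
#     class ReportOperator(BaseOperator):
#         def execute(self, context):
#             count = self.xcom_pull(context, task_ids="count_rows", key="row_count")
#             self.log.info("Upstream counted %s rows", count)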
1517 @classmethod
1518 def get_serialized_fields(cls):
1519 """Stringified DAGs and operators contain exactly these fields."""
1520 if not cls.__serialized_fields:
1521 from airflow.models.dag import DagContext
1523 # Make sure the following dummy task is not added to the currently active
1524 # DAG in context; otherwise it results in a
1525 # `RuntimeError: dictionary changed size during iteration`
1526 # exception in the SerializedDAG.serialize_dag() call.
1527 DagContext.push_context_managed_dag(None)
1528 cls.__serialized_fields = frozenset(
1529 vars(BaseOperator(task_id="test")).keys()
1530 - {
1531 "upstream_task_ids",
1532 "default_args",
1533 "dag",
1534 "_dag",
1535 "label",
1536 "_BaseOperator__instantiated",
1537 "_BaseOperator__init_kwargs",
1538 "_BaseOperator__from_mapped",
1539 }
1540 | { # Class level defaults need to be added to this list
1541 "start_date",
1542 "end_date",
1543 "_task_type",
1544 "_operator_name",
1545 "subdag",
1546 "ui_color",
1547 "ui_fgcolor",
1548 "template_ext",
1549 "template_fields",
1550 "template_fields_renderers",
1551 "params",
1552 "is_setup",
1553 "is_teardown",
1554 "on_failure_fail_dagrun",
1555 }
1556 )
1557 DagContext.pop_context_managed_dag()
1559 return cls.__serialized_fields
1561 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
1562 """Required by DAGNode."""
1563 return DagAttributeTypes.OP, self.task_id
1565 @property
1566 def inherits_from_empty_operator(self):
1567 """Used to determine if an Operator is inherited from EmptyOperator."""
1568 # This looks like `isinstance(self, EmptyOperator)` would work, but this also
1569 # needs to cope when `self` is a serialized instance of an EmptyOperator or one
1570 # of its subclasses (which don't inherit from anything but BaseOperator).
1571 return getattr(self, "_is_empty", False)
1573 def defer(
1574 self,
1575 *,
1576 trigger: BaseTrigger,
1577 method_name: str,
1578 kwargs: dict[str, Any] | None = None,
1579 timeout: timedelta | None = None,
1580 ):
1581 """
1582 Mark this Operator as "deferred", suspending its
1583 execution until the provided trigger fires an event.
1585 This is achieved by raising a special exception (TaskDeferred)
1586 which is caught in the main _execute_task wrapper.
1587 """
1588 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
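# Editor's note: a minimal sketch, not part of the original source, of the deferral pattern.
# TimeDeltaTrigger ships with Airflow; the operator itself is hypothetical:
#
#     from datetime import timedelta
#     from airflow.triggers.temporal import TimeDeltaTrigger
#
#     class WaitOneHourOperator(BaseOperator):
#         def execute(self, context):
#             # Release the worker slot; the triggerer resumes the task when the trigger fires.
#             self.defer(trigger=TimeDeltaTrigger(timedelta(hours=1)), method_name="execute_complete")
#
#         def execute_complete(self, context, event=None):
#             return "done"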
1590 def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
1591 """Get the "normal" operator from the current operator.
1593 Since a BaseOperator is not mapped to begin with, this simply returns
1594 the original operator.
1596 :meta private:
1597 """
1598 return self
1601# TODO: Deprecate for Airflow 3.0
1602Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
1605def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
1606 r"""
1607 Given a number of tasks, builds a dependency chain.
1609 This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
1610 or lists containing any mix of these types. If you want to chain between two
1611 lists, you must ensure they have the same length.
1613 Using classic operators/sensors:
1615 .. code-block:: python
1617 chain(t1, [t2, t3], [t4, t5], t6)
1619 is equivalent to::
1621 / -> t2 -> t4 \
1622 t1 -> t6
1623 \ -> t3 -> t5 /
1625 .. code-block:: python
1627 t1.set_downstream(t2)
1628 t1.set_downstream(t3)
1629 t2.set_downstream(t4)
1630 t3.set_downstream(t5)
1631 t4.set_downstream(t6)
1632 t5.set_downstream(t6)
1634 Using task-decorated functions aka XComArgs:
1636 .. code-block:: python
1638 chain(x1(), [x2(), x3()], [x4(), x5()], x6())
1640 is equivalent to::
1642 / -> x2 -> x4 \
1643 x1 -> x6
1644 \ -> x3 -> x5 /
1646 .. code-block:: python
1648 x1 = x1()
1649 x2 = x2()
1650 x3 = x3()
1651 x4 = x4()
1652 x5 = x5()
1653 x6 = x6()
1654 x1.set_downstream(x2)
1655 x1.set_downstream(x3)
1656 x2.set_downstream(x4)
1657 x3.set_downstream(x5)
1658 x4.set_downstream(x6)
1659 x5.set_downstream(x6)
1661 Using TaskGroups:
1663 .. code-block:: python
1665 chain(t1, task_group1, task_group2, t2)
1667 t1.set_downstream(task_group1)
1668 task_group1.set_downstream(task_group2)
1669 task_group2.set_downstream(t2)
1672 It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:
1674 .. code-block:: python
1676 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
1678 is equivalent to::
1680 / "branch one" -> x1 \
1681 t1 -> task_group1 -> x3
1682 \ "branch two" -> x2 /
1684 .. code-block:: python
1686 x1 = x1()
1687 x2 = x2()
1688 x3 = x3()
1689 label1 = Label("branch one")
1690 label2 = Label("branch two")
1691 t1.set_downstream(label1)
1692 label1.set_downstream(x1)
1693 t1.set_downstream(label2)
1694 label2.set_downstream(x2)
1695 x1.set_downstream(task_group1)
1696 x2.set_downstream(task_group1)
1697 task_group1.set_downstream(x3)
1699 # or
1701 x1 = x1()
1702 x2 = x2()
1703 x3 = x3()
1704 t1.set_downstream(x1, edge_modifier=Label("branch one"))
1705 t1.set_downstream(x2, edge_modifier=Label("branch two"))
1706 x1.set_downstream(task_group1)
1707 x2.set_downstream(task_group1)
1708 task_group1.set_downstream(x3)
1711 :param tasks: Individual and/or lists of tasks, EdgeModifiers, XComArgs, or TaskGroups on which to set dependencies.
1712 """
1713 for index, up_task in enumerate(tasks[:-1]):
1714 down_task = tasks[index + 1]
1715 if isinstance(up_task, DependencyMixin):
1716 up_task.set_downstream(down_task)
1717 continue
1718 if isinstance(down_task, DependencyMixin):
1719 down_task.set_upstream(up_task)
1720 continue
1721 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
1722 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
1723 up_task_list = up_task
1724 down_task_list = down_task
1725 if len(up_task_list) != len(down_task_list):
1726 raise AirflowException(
1727 f"Chain not supported for different length Iterable. "
1728 f"Got {len(up_task_list)} and {len(down_task_list)}."
1729 )
1730 for up_t, down_t in zip(up_task_list, down_task_list):
1731 up_t.set_downstream(down_t)
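# Editor's note: a minimal sketch, not part of the original source, showing ``chain`` inside a
# DAG context; the dag id, start date, and EmptyOperator task ids are illustrative:
#
#     from datetime import datetime
#
#     from airflow.models.dag import DAG
#     from airflow.operators.empty import EmptyOperator
#
#     with DAG(dag_id="chain_example", start_date=datetime(2023, 1, 1), schedule=None):
#         t1, t2, t3, t4, t5, t6 = (EmptyOperator(task_id=f"t{i}") for i in range(1, 7))
#         chain(t1, [t2, t3], [t4, t5], t6)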
1734def cross_downstream(
1735 from_tasks: Sequence[DependencyMixin],
1736 to_tasks: DependencyMixin | Sequence[DependencyMixin],
1737):
1738 r"""
1739 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
1741 Using classic operators/sensors:
1743 .. code-block:: python
1745 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
1747 is equivalent to::
1749 t1 ---> t4
1750 \ /
1751 t2 -X -> t5
1752 / \
1753 t3 ---> t6
1755 .. code-block:: python
1757 t1.set_downstream(t4)
1758 t1.set_downstream(t5)
1759 t1.set_downstream(t6)
1760 t2.set_downstream(t4)
1761 t2.set_downstream(t5)
1762 t2.set_downstream(t6)
1763 t3.set_downstream(t4)
1764 t3.set_downstream(t5)
1765 t3.set_downstream(t6)
1767 Using task-decorated functions aka XComArgs:
1769 .. code-block:: python
1771 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
1773 is equivalent to::
1775 x1 ---> x4
1776 \ /
1777 x2 -X -> x5
1778 / \
1779 x3 ---> x6
1781 .. code-block:: python
1783 x1 = x1()
1784 x2 = x2()
1785 x3 = x3()
1786 x4 = x4()
1787 x5 = x5()
1788 x6 = x6()
1789 x1.set_downstream(x4)
1790 x1.set_downstream(x5)
1791 x1.set_downstream(x6)
1792 x2.set_downstream(x4)
1793 x2.set_downstream(x5)
1794 x2.set_downstream(x6)
1795 x3.set_downstream(x4)
1796 x3.set_downstream(x5)
1797 x3.set_downstream(x6)
1799 It is also possible to mix between classic operator/sensor and XComArg tasks:
1801 .. code-block:: python
1803 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1805 is equivalent to::
1807 t1 ---> x1
1808 \ /
1809 x2 -X -> t2
1810 / \
1811 t3 ---> x3
1813 .. code-block:: python
1815 x1 = x1()
1816 x2 = x2()
1817 x3 = x3()
1818 t1.set_downstream(x1)
1819 t1.set_downstream(t2)
1820 t1.set_downstream(x3)
1821 x2.set_downstream(x1)
1822 x2.set_downstream(t2)
1823 x2.set_downstream(x3)
1824 t3.set_downstream(x1)
1825 t3.set_downstream(t2)
1826 t3.set_downstream(x3)
1828 :param from_tasks: List of tasks or XComArgs to start from.
1829 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1830 """
1831 for task in from_tasks:
1832 task.set_downstream(to_tasks)
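# Editor's note: a minimal sketch, not part of the original source, mirroring the docstring
# above with EmptyOperator tasks (assumed to be created inside an active DAG context):
#
#     from airflow.operators.empty import EmptyOperator
#
#     extract = [EmptyOperator(task_id=f"extract_{i}") for i in range(3)]
#     load = [EmptyOperator(task_id=f"load_{i}") for i in range(3)]
#     cross_downstream(from_tasks=extract, to_tasks=load)  # every extract_* precedes every load_*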
1835# pyupgrade assumes all type annotations can be lazily evaluated, but this is
1836# not the case for attrs-decorated classes, since cattrs needs to evaluate the
1837# annotation expressions at runtime, and Python before 3.9.0 does not lazily
1838# evaluate those. Putting the expression in a top-level assignment statement
1839# communicates this runtime requirement to pyupgrade.
1840BaseOperatorClassList = List[Type[BaseOperator]]
1843@attr.s(auto_attribs=True)
1844class BaseOperatorLink(metaclass=ABCMeta):
1845 """Abstract base class that defines how we get an operator link."""
1847 operators: ClassVar[BaseOperatorClassList] = []
1848 """
1849 This property will be used by Airflow Plugins to find the Operators to which you want
1850 to assign this Operator Link.
1852 :return: List of Operator classes used by tasks for which you want to create an extra link.
1853 """
1855 @property
1856 @abstractmethod
1857 def name(self) -> str:
1858 """Name of the link. This will be the button name on the task UI."""
1860 @abstractmethod
1861 def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
1862 """Link to external system.
1864 Note: The old signature of this function was ``(self, operator, dttm: datetime)``. That is still
1865 supported at runtime but is deprecated.
1867 :param operator: The Airflow operator object this link is associated to.
1868 :param ti_key: TaskInstance ID to return link for.
1869 :return: link to external system
1870 """