Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/baseoperator.py: 53%
604 statements
coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Base operator for all operators."""
19from __future__ import annotations
21import abc
22import collections
23import collections.abc
24import contextlib
25import copy
26import functools
27import logging
28import sys
29import warnings
30from abc import ABCMeta, abstractmethod
31from datetime import datetime, timedelta
32from inspect import signature
33from types import FunctionType
34from typing import (
35 TYPE_CHECKING,
36 Any,
37 Callable,
38 ClassVar,
39 Collection,
40 Iterable,
41 List,
42 Sequence,
43 Type,
44 TypeVar,
45 Union,
46 cast,
47)
49import attr
50import pendulum
51from dateutil.relativedelta import relativedelta
52from sqlalchemy.orm import Session
53from sqlalchemy.orm.exc import NoResultFound
55from airflow.configuration import conf
56from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred
57from airflow.lineage import apply_lineage, prepare_lineage
58from airflow.models.abstractoperator import (
59 DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
60 DEFAULT_OWNER,
61 DEFAULT_POOL_SLOTS,
62 DEFAULT_PRIORITY_WEIGHT,
63 DEFAULT_QUEUE,
64 DEFAULT_RETRIES,
65 DEFAULT_RETRY_DELAY,
66 DEFAULT_TASK_EXECUTION_TIMEOUT,
67 DEFAULT_TRIGGER_RULE,
68 DEFAULT_WEIGHT_RULE,
69 AbstractOperator,
70 TaskStateChangeCallback,
71)
72from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs
73from airflow.models.param import ParamsDict
74from airflow.models.pool import Pool
75from airflow.models.taskinstance import TaskInstance, clear_task_instances
76from airflow.models.taskmixin import DAGNode, DependencyMixin
77from airflow.models.xcom import XCOM_RETURN_KEY
78from airflow.serialization.enums import DagAttributeTypes
79from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
80from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
81from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
82from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
83from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
84from airflow.triggers.base import BaseTrigger
85from airflow.utils import timezone
86from airflow.utils.context import Context
87from airflow.utils.decorators import fixup_decorator_warning_stack
88from airflow.utils.helpers import validate_key
89from airflow.utils.operator_resources import Resources
90from airflow.utils.session import NEW_SESSION, provide_session
91from airflow.utils.trigger_rule import TriggerRule
92from airflow.utils.weight_rule import WeightRule
94if TYPE_CHECKING:
95 import jinja2 # Slow import.
97 from airflow.models.dag import DAG
98 from airflow.models.taskinstance import TaskInstanceKey
99 from airflow.models.xcom_arg import XComArg
100 from airflow.utils.task_group import TaskGroup
102ScheduleInterval = Union[str, timedelta, relativedelta]
104TaskPreExecuteHook = Callable[[Context], None]
105TaskPostExecuteHook = Callable[[Context, Any], None]
107T = TypeVar("T", bound=FunctionType)
109logger = logging.getLogger("airflow.models.baseoperator.BaseOperator")
112def parse_retries(retries: Any) -> int | None:
113 if retries is None or isinstance(retries, int):
114 return retries
115 try:
116 parsed_retries = int(retries)
117 except (TypeError, ValueError):
118 raise AirflowException(f"'retries' type must be int, not {type(retries).__name__}")
119 logger.warning("Implicitly converting 'retries' from %r to int", retries)
120 return parsed_retries
123def coerce_timedelta(value: float | timedelta, *, key: str) -> timedelta:
124 if isinstance(value, timedelta):
125 return value
126 logger.debug("%s isn't a timedelta object, assuming secs", key)
127 return timedelta(seconds=value)
130def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
131 if resources is None:
132 return None
133 return Resources(**resources)
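# Illustrative use of the coercion helpers above -- a minimal, standalone sketch;
# the literal values are examples only, and the import path assumes this module's
# usual location in an installed Airflow.
from datetime import timedelta

from airflow.models.baseoperator import coerce_resources, coerce_timedelta, parse_retries

assert parse_retries("3") == 3  # logs a warning about the implicit conversion
assert parse_retries(None) is None
assert coerce_timedelta(300, key="retry_delay") == timedelta(seconds=300)
assert coerce_resources(None) is None
resources = coerce_resources({"cpus": 1, "ram": 1024})  # becomes a Resources object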
136def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
137 if not dag:
138 return {}, ParamsDict()
139 dag_args = copy.copy(dag.default_args)
140 dag_params = copy.deepcopy(dag.params)
141 if task_group:
142 if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
143 raise TypeError("default_args must be a mapping")
144 dag_args.update(task_group.default_args)
145 return dag_args, dag_params
148def get_merged_defaults(
149 dag: DAG | None,
150 task_group: TaskGroup | None,
151 task_params: dict | None,
152 task_default_args: dict | None,
153) -> tuple[dict, ParamsDict]:
154 args, params = _get_parent_defaults(dag, task_group)
155 if task_params:
156 if not isinstance(task_params, collections.abc.Mapping):
157 raise TypeError("params must be a mapping")
158 params.update(task_params)
159 if task_default_args:
160 if not isinstance(task_default_args, collections.abc.Mapping):
161 raise TypeError("default_args must be a mapping")
162 args.update(task_default_args)
163 with contextlib.suppress(KeyError):
164 params.update(task_default_args["params"] or {})
165 return args, params
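# Minimal sketch of the merge precedence implemented above: DAG-level defaults are
# the base, task-group defaults layer on top, and explicit task-level values win.
# The dag id and argument values here are hypothetical.
from airflow.models.baseoperator import get_merged_defaults
from airflow.models.dag import DAG

dag = DAG("defaults_merge_example", default_args={"retries": 1, "owner": "data-eng"})
args, params = get_merged_defaults(
    dag=dag,
    task_group=None,
    task_params={"limit": 10},
    task_default_args={"retries": 3},
)
assert args == {"retries": 3, "owner": "data-eng"}  # task-level retries wins
assert params["limit"] == 10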
168class _PartialDescriptor:
169 """A descriptor that guards against ``.partial`` being called on Task objects."""
171 class_method = None
173 def __get__(
174 self, obj: BaseOperator, cls: type[BaseOperator] | None = None
175 ) -> Callable[..., OperatorPartial]:
176 # Call this "partial" so it looks nicer in stack traces.
177 def partial(**kwargs):
178 raise TypeError("partial can only be called on Operator classes, not Tasks themselves")
180 if obj is not None:
181 return partial
182 return self.class_method.__get__(cls, cls)
185# This is what handles the actual mapping.
186def partial(
187 operator_class: type[BaseOperator],
188 *,
189 task_id: str,
190 dag: DAG | None = None,
191 task_group: TaskGroup | None = None,
192 start_date: datetime | None = None,
193 end_date: datetime | None = None,
194 owner: str = DEFAULT_OWNER,
195 email: None | str | Iterable[str] = None,
196 params: dict | None = None,
197 resources: dict[str, Any] | None = None,
198 trigger_rule: str = DEFAULT_TRIGGER_RULE,
199 depends_on_past: bool = False,
200 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
201 wait_for_downstream: bool = False,
202 retries: int | None = DEFAULT_RETRIES,
203 queue: str = DEFAULT_QUEUE,
204 pool: str | None = None,
205 pool_slots: int = DEFAULT_POOL_SLOTS,
206 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
207 max_retry_delay: None | timedelta | float = None,
208 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
209 retry_exponential_backoff: bool = False,
210 priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
211 weight_rule: str = DEFAULT_WEIGHT_RULE,
212 sla: timedelta | None = None,
213 max_active_tis_per_dag: int | None = None,
214 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
215 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
216 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
217 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
218 run_as_user: str | None = None,
219 executor_config: dict | None = None,
220 inlets: Any | None = None,
221 outlets: Any | None = None,
222 doc: str | None = None,
223 doc_md: str | None = None,
224 doc_json: str | None = None,
225 doc_yaml: str | None = None,
226 doc_rst: str | None = None,
227 **kwargs,
228) -> OperatorPartial:
229 from airflow.models.dag import DagContext
230 from airflow.utils.task_group import TaskGroupContext
232 validate_mapping_kwargs(operator_class, "partial", kwargs)
234 dag = dag or DagContext.get_current_dag()
235 if dag:
236 task_group = TaskGroupContext.get_current_task_group(dag)
237 if task_group:
238 task_id = task_group.child_id(task_id)
240 # Merge DAG and task group level defaults into user-supplied values.
241 partial_kwargs, partial_params = get_merged_defaults(
242 dag=dag,
243 task_group=task_group,
244 task_params=params,
245 task_default_args=kwargs.pop("default_args", None),
246 )
247 partial_kwargs.update(kwargs)
249 # Always fully populate partial kwargs to exclude them from map().
250 partial_kwargs.setdefault("dag", dag)
251 partial_kwargs.setdefault("task_group", task_group)
252 partial_kwargs.setdefault("task_id", task_id)
253 partial_kwargs.setdefault("start_date", start_date)
254 partial_kwargs.setdefault("end_date", end_date)
255 partial_kwargs.setdefault("owner", owner)
256 partial_kwargs.setdefault("email", email)
257 partial_kwargs.setdefault("trigger_rule", trigger_rule)
258 partial_kwargs.setdefault("depends_on_past", depends_on_past)
259 partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past)
260 partial_kwargs.setdefault("wait_for_downstream", wait_for_downstream)
261 partial_kwargs.setdefault("retries", retries)
262 partial_kwargs.setdefault("queue", queue)
263 partial_kwargs.setdefault("pool", pool)
264 partial_kwargs.setdefault("pool_slots", pool_slots)
265 partial_kwargs.setdefault("execution_timeout", execution_timeout)
266 partial_kwargs.setdefault("max_retry_delay", max_retry_delay)
267 partial_kwargs.setdefault("retry_delay", retry_delay)
268 partial_kwargs.setdefault("retry_exponential_backoff", retry_exponential_backoff)
269 partial_kwargs.setdefault("priority_weight", priority_weight)
270 partial_kwargs.setdefault("weight_rule", weight_rule)
271 partial_kwargs.setdefault("sla", sla)
272 partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag)
273 partial_kwargs.setdefault("on_execute_callback", on_execute_callback)
274 partial_kwargs.setdefault("on_failure_callback", on_failure_callback)
275 partial_kwargs.setdefault("on_retry_callback", on_retry_callback)
276 partial_kwargs.setdefault("on_success_callback", on_success_callback)
277 partial_kwargs.setdefault("run_as_user", run_as_user)
278 partial_kwargs.setdefault("executor_config", executor_config)
279 partial_kwargs.setdefault("inlets", inlets or [])
280 partial_kwargs.setdefault("outlets", outlets or [])
281 partial_kwargs.setdefault("resources", resources)
282 partial_kwargs.setdefault("doc", doc)
283 partial_kwargs.setdefault("doc_json", doc_json)
284 partial_kwargs.setdefault("doc_md", doc_md)
285 partial_kwargs.setdefault("doc_rst", doc_rst)
286 partial_kwargs.setdefault("doc_yaml", doc_yaml)
288 # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
289 if "task_concurrency" in kwargs: # Reject deprecated option.
290 raise TypeError("unexpected argument: task_concurrency")
291 if partial_kwargs["wait_for_downstream"]:
292 partial_kwargs["depends_on_past"] = True
293 partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"])
294 partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"])
295 if partial_kwargs["pool"] is None:
296 partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
297 partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"])
298 partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay")
299 if partial_kwargs["max_retry_delay"] is not None:
300 partial_kwargs["max_retry_delay"] = coerce_timedelta(
301 partial_kwargs["max_retry_delay"],
302 key="max_retry_delay",
303 )
304 partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {}
305 partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"])
307 return OperatorPartial(
308 operator_class=operator_class,
309 kwargs=partial_kwargs,
310 params=partial_params,
311 )
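# How the partial() machinery above is usually reached from DAG code: through
# Operator.partial(...).expand(...) for dynamic task mapping. A sketch assuming the
# standard BashOperator; the dag id, task id, and commands are illustrative.
import pendulum

from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator

with DAG(dag_id="mapping_example", start_date=pendulum.datetime(2022, 1, 1), schedule=None):
    # Constant arguments go to partial(); the argument to map over goes to expand().
    BashOperator.partial(task_id="echo", retries=1).expand(
        bash_command=["echo 1", "echo 2", "echo 3"],
    )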
314class BaseOperatorMeta(abc.ABCMeta):
315 """Metaclass of BaseOperator."""
317 @classmethod
318 def _apply_defaults(cls, func: T) -> T:
319 """
320 Function decorator that looks for an argument named "default_args" and
321 fills the unspecified arguments from it.
323 Since python2.* isn't clear about which arguments are missing when
324 calling a function, and this can be quite confusing with multi-level
325 inheritance and argument defaults, this decorator also alerts with
326 specific information about the missing arguments.
327 """
328 # Cache inspect.signature for the wrapper closure to avoid calling it
329 at every decorated invocation. This is a separate sig_cache created
330 # per decoration, i.e. each function decorated using apply_defaults will
331 # have a different sig_cache.
332 sig_cache = signature(func)
333 non_variadic_params = {
334 name: param
335 for (name, param) in sig_cache.parameters.items()
336 if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
337 }
338 non_optional_args = {
339 name
340 for name, param in non_variadic_params.items()
341 if param.default == param.empty and name != "task_id"
342 }
344 fixup_decorator_warning_stack(func)
346 @functools.wraps(func)
347 def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
348 from airflow.models.dag import DagContext
349 from airflow.utils.task_group import TaskGroupContext
351 if len(args) > 0:
352 raise AirflowException("Use keyword arguments when initializing operators")
354 instantiated_from_mapped = kwargs.pop(
355 "_airflow_from_mapped",
356 getattr(self, "_BaseOperator__from_mapped", False),
357 )
359 dag: DAG | None = kwargs.get("dag") or DagContext.get_current_dag()
360 task_group: TaskGroup | None = kwargs.get("task_group")
361 if dag and not task_group:
362 task_group = TaskGroupContext.get_current_task_group(dag)
364 default_args, merged_params = get_merged_defaults(
365 dag=dag,
366 task_group=task_group,
367 task_params=kwargs.pop("params", None),
368 task_default_args=kwargs.pop("default_args", None),
369 )
371 for arg in sig_cache.parameters:
372 if arg not in kwargs and arg in default_args:
373 kwargs[arg] = default_args[arg]
375 missing_args = non_optional_args - set(kwargs)
376 if len(missing_args) == 1:
377 raise AirflowException(f"missing keyword argument {missing_args.pop()!r}")
378 elif missing_args:
379 display = ", ".join(repr(a) for a in sorted(missing_args))
380 raise AirflowException(f"missing keyword arguments {display}")
382 if merged_params:
383 kwargs["params"] = merged_params
385 hook = getattr(self, "_hook_apply_defaults", None)
386 if hook:
387 args, kwargs = hook(**kwargs, default_args=default_args)
388 default_args = kwargs.pop("default_args", {})
390 if not hasattr(self, "_BaseOperator__init_kwargs"):
391 self._BaseOperator__init_kwargs = {}
392 self._BaseOperator__from_mapped = instantiated_from_mapped
394 result = func(self, **kwargs, default_args=default_args)
396 # Store the args passed to init -- we need them to support task.map serialization!
397 self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
399 # Set upstream task defined by XComArgs passed to template fields of the operator.
400 # BUT: only do this _ONCE_, not once for each class in the hierarchy
401 if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
402 self.set_xcomargs_dependencies()
403 # Mark instance as instantiated.
404 self._BaseOperator__instantiated = True
406 return result
408 apply_defaults.__non_optional_args = non_optional_args # type: ignore
409 apply_defaults.__param_names = set(non_variadic_params) # type: ignore
411 return cast(T, apply_defaults)
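# Effect of the decorator above, sketched with a hypothetical DAG: constructor
# arguments missing from the call are filled in from default_args before __init__
# runs, and a missing required argument raises an AirflowException.
import pendulum

from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator

with DAG(
    dag_id="defaults_example",
    start_date=pendulum.datetime(2022, 1, 1),
    schedule=None,
    default_args={"retries": 2, "owner": "data-eng"},
):
    task = BashOperator(task_id="hello", bash_command="echo hello")

assert task.retries == 2 and task.owner == "data-eng"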
413 def __new__(cls, name, bases, namespace, **kwargs):
414 new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
415 with contextlib.suppress(KeyError):
416 # Update the partial descriptor with the class method so it can call the actual function (but let
417 # subclasses override it if they need to)
418 partial_desc = vars(new_cls)["partial"]
419 if isinstance(partial_desc, _PartialDescriptor):
420 partial_desc.class_method = classmethod(partial)
421 new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
422 return new_cls
425@functools.total_ordering
426class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
427 """
428 Abstract base class for all operators. Since operators create objects that
429 become nodes in the dag, BaseOperator contains many recursive methods for
430 dag crawling behavior. To derive this class, you are expected to override
431 the constructor as well as the 'execute' method.
433 Operators derived from this class should perform or trigger certain tasks
434 synchronously (wait for completion). Example of operators could be an
435 operator that runs a Pig job (PigOperator), a sensor operator that
436 waits for a partition to land in Hive (HiveSensorOperator), or one that
437 moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these
438 operators (tasks) target specific operations, running specific scripts,
439 functions or data transfers.
441 This class is abstract and shouldn't be instantiated. Instantiating a
442 class derived from this one results in the creation of a task object,
443 which ultimately becomes a node in DAG objects. Task dependencies should
444 be set by using the set_upstream and/or set_downstream methods.
446 :param task_id: a unique, meaningful id for the task
447 :param owner: the owner of the task. Using a meaningful description
448 (e.g. user/person/team/role name) to clarify ownership is recommended.
449 :param email: the 'to' email address(es) used in email alerts. This can be a
450 single email or multiple ones. Multiple addresses can be specified as a
451 comma or semi-colon separated string or by passing a list of strings.
452 :param email_on_retry: Indicates whether email alerts should be sent when a
453 task is retried
454 :param email_on_failure: Indicates whether email alerts should be sent when
455 a task fails
456 :param retries: the number of retries that should be performed before
457 failing the task
458 :param retry_delay: delay between retries, can be set as ``timedelta`` or
459 ``float`` seconds, which will be converted into ``timedelta``,
460 the default is ``timedelta(seconds=300)``.
461 :param retry_exponential_backoff: allow progressively longer waits between
462 retries by using exponential backoff algorithm on retry delay (delay
463 will be converted into seconds)
464 :param max_retry_delay: maximum delay interval between retries, can be set as
465 ``timedelta`` or ``float`` seconds, which will be converted into ``timedelta``.
466 :param start_date: The ``start_date`` for the task, determines
467 the ``execution_date`` for the first task instance. The best practice
468 is to have the start_date rounded
469 to your DAG's ``schedule_interval``. Daily jobs have their start_date
470 some day at 00:00:00, hourly jobs have their start_date at 00:00
471 of a specific hour. Note that Airflow simply looks at the latest
472 ``execution_date`` and adds the ``schedule_interval`` to determine
473 the next ``execution_date``. It is also very important
474 to note that different tasks' dependencies
475 need to line up in time. If task A depends on task B and their
476 start_date are offset in a way that their execution_date don't line
477 up, A's dependencies will never be met. If you are looking to delay
478 a task, for example running a daily task at 2AM, look into the
479 ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
480 dynamic ``start_date`` and recommend using fixed ones. Read the
481 FAQ entry about start_date for more information.
482 :param end_date: if specified, the scheduler won't go beyond this date
483 :param depends_on_past: when set to true, task instances will run
484 sequentially and only if the previous instance has succeeded or has been skipped.
485 The task instance for the start_date is allowed to run.
486 :param wait_for_downstream: when set to true, an instance of task
487 X will wait for tasks immediately downstream of the previous instance
488 of task X to finish successfully or be skipped before it runs. This is useful if the
489 different instances of a task X alter the same asset, and this asset
490 is used by tasks downstream of task X. Note that depends_on_past
491 is forced to True wherever wait_for_downstream is used. Also note that
492 only tasks *immediately* downstream of the previous task instance are waited
493 for; the statuses of any tasks further downstream are ignored.
494 :param dag: a reference to the dag the task is attached to (if any)
495 :param priority_weight: priority weight of this task against other tasks.
496 This allows the executor to trigger higher priority tasks before
497 others when things get backed up. Set priority_weight as a higher
498 number for more important tasks.
499 :param weight_rule: weighting method used for the effective total
500 priority weight of the task. Options are:
501 ``{ downstream | upstream | absolute }`` default is ``downstream``
502 When set to ``downstream`` the effective weight of the task is the
503 aggregate sum of all downstream descendants. As a result, upstream
504 tasks will have higher weight and will be scheduled more aggressively
505 when using positive weight values. This is useful when you have
506 multiple dag run instances and desire to have all upstream tasks to
507 complete for all runs before each dag can continue processing
508 downstream tasks. When set to ``upstream`` the effective weight is the
509 aggregate sum of all upstream ancestors. This is the opposite where
510 downstream tasks have higher weight and will be scheduled more
511 aggressively when using positive weight values. This is useful when you
512 have multiple dag run instances and prefer to have each dag complete
513 before starting upstream tasks of other dags. When set to
514 ``absolute``, the effective weight is the exact ``priority_weight``
515 specified without additional weighting. You may want to do this when
516 you know exactly what priority weight each task should have.
517 Additionally, when set to ``absolute``, there is a bonus effect of
518 significantly speeding up the task creation process for very large
519 DAGs. Options can be set as string or using the constants defined in
520 the static class ``airflow.utils.WeightRule``
521 :param queue: which queue to target when running this job. Not
522 all executors implement queue management; the CeleryExecutor
523 does support targeting specific queues.
524 :param pool: the slot pool this task should run in, slot pools are a
525 way to limit concurrency for certain tasks
526 :param pool_slots: the number of pool slots this task should use (>= 1)
527 Values less than 1 are not allowed.
528 :param sla: time by which the job is expected to succeed. Note that
529 this represents the ``timedelta`` after the period is closed. For
530 example if you set an SLA of 1 hour, the scheduler would send an email
531 soon after 1:00AM on the ``2016-01-02`` if the ``2016-01-01`` instance
532 has not succeeded yet.
533 The scheduler pays special attention to jobs with an SLA and
534 sends alert
535 emails for SLA misses. SLA misses are also recorded in the database
536 for future reference. All tasks that share the same SLA time
537 get bundled in a single email, sent soon after that time. SLA
538 notifications are sent once and only once for each task instance.
539 :param execution_timeout: max time allowed for the execution of
540 this task instance; if it goes beyond this limit, the task will raise and fail.
541 :param on_failure_callback: a function or list of functions to be called when a task instance
542 of this task fails. A context dictionary is passed as a single
543 parameter to this function. Context contains references to related
544 objects to the task instance and is documented under the macros
545 section of the API.
546 :param on_execute_callback: much like the ``on_failure_callback`` except
547 that it is executed right before the task is executed.
548 :param on_retry_callback: much like the ``on_failure_callback`` except
549 that it is executed when retries occur.
550 :param on_success_callback: much like the ``on_failure_callback`` except
551 that it is executed when the task succeeds.
552 :param pre_execute: a function to be called immediately before task
553 execution, receiving a context dictionary; raising an exception will
554 prevent the task from being executed.
556 |experimental|
557 :param post_execute: a function to be called immediately after task
558 execution, receiving a context dictionary and task result; raising an
559 exception will prevent the task from succeeding.
561 |experimental|
562 :param trigger_rule: defines the rule by which dependencies are applied
563 for the task to get triggered. Options are:
564 ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
565 one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
566 default is ``all_success``. Options can be set as string or
567 using the constants defined in the static class
568 ``airflow.utils.TriggerRule``
569 :param resources: A map of resource parameter names (the argument names of the
570 Resources constructor) to their values.
571 :param run_as_user: unix username to impersonate while running the task
572 :param max_active_tis_per_dag: When set, a task will be able to limit the concurrent
573 runs across execution_dates.
574 :param executor_config: Additional task-level configuration parameters that are
575 interpreted by a specific executor. Parameters are namespaced by the name of the
576 executor.
578 **Example**: to run this task in a specific docker container through
579 the KubernetesExecutor ::
581 MyOperator(...,
582 executor_config={
583 "KubernetesExecutor":
584 {"image": "myCustomDockerImage"}
585 }
586 )
588 :param do_xcom_push: if True, an XCom is pushed containing the Operator's
589 result
590 :param task_group: The TaskGroup to which the task should belong. This is typically provided when not
591 using a TaskGroup as a context manager.
592 :param doc: Add documentation or notes to your Task objects that is visible in
593 Task Instance details View in the Webserver
594 :param doc_md: Add documentation (in Markdown format) or notes to your Task objects
595 that is visible in Task Instance details View in the Webserver
596 :param doc_rst: Add documentation (in RST format) or notes to your Task objects
597 that is visible in Task Instance details View in the Webserver
598 :param doc_json: Add documentation (in JSON format) or notes to your Task objects
599 that is visible in Task Instance details View in the Webserver
600 :param doc_yaml: Add documentation (in YAML format) or notes to your Task objects
601 that is visible in Task Instance details View in the Webserver
602 """
604 # Implementing Operator.
605 template_fields: Sequence[str] = ()
606 template_ext: Sequence[str] = ()
608 template_fields_renderers: dict[str, str] = {}
610 # Defines the color in the UI
611 ui_color: str = "#fff"
612 ui_fgcolor: str = "#000"
614 pool: str = ""
616 # base list which includes all the attrs that don't need deep copy.
617 _base_operator_shallow_copy_attrs: tuple[str, ...] = (
618 "user_defined_macros",
619 "user_defined_filters",
620 "params",
621 "_log",
622 )
624 # each operator should override this class attr for shallow copy attrs.
625 shallow_copy_attrs: Sequence[str] = ()
627 # Defines the operator level extra links
628 operator_extra_links: Collection[BaseOperatorLink] = ()
630 # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
631 __serialized_fields: frozenset[str] | None = None
633 partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore
635 _comps = {
636 "task_id",
637 "dag_id",
638 "owner",
639 "email",
640 "email_on_retry",
641 "retry_delay",
642 "retry_exponential_backoff",
643 "max_retry_delay",
644 "start_date",
645 "end_date",
646 "depends_on_past",
647 "wait_for_downstream",
648 "priority_weight",
649 "sla",
650 "execution_timeout",
651 "on_execute_callback",
652 "on_failure_callback",
653 "on_success_callback",
654 "on_retry_callback",
655 "do_xcom_push",
656 }
658 # Defines if the operator supports lineage without manual definitions
659 supports_lineage = False
661 # If True then the class constructor was called
662 __instantiated = False
663 # List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task
664 # when mapping
665 __init_kwargs: dict[str, Any]
667 # Set to True before calling execute method
668 _lock_for_execution = False
670 _dag: DAG | None = None
671 task_group: TaskGroup | None = None
673 # subdag parameter is only set for SubDagOperator.
674 # Setting it to None by default as other Operators do not have that field
675 subdag: DAG | None = None
677 start_date: pendulum.DateTime | None = None
678 end_date: pendulum.DateTime | None = None
680 # Set to True for an operator instantiated by a mapped operator.
681 __from_mapped = False
683 def __init__(
684 self,
685 task_id: str,
686 owner: str = DEFAULT_OWNER,
687 email: str | Iterable[str] | None = None,
688 email_on_retry: bool = conf.getboolean("email", "default_email_on_retry", fallback=True),
689 email_on_failure: bool = conf.getboolean("email", "default_email_on_failure", fallback=True),
690 retries: int | None = DEFAULT_RETRIES,
691 retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
692 retry_exponential_backoff: bool = False,
693 max_retry_delay: timedelta | float | None = None,
694 start_date: datetime | None = None,
695 end_date: datetime | None = None,
696 depends_on_past: bool = False,
697 ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
698 wait_for_downstream: bool = False,
699 dag: DAG | None = None,
700 params: dict | None = None,
701 default_args: dict | None = None,
702 priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
703 weight_rule: str = DEFAULT_WEIGHT_RULE,
704 queue: str = DEFAULT_QUEUE,
705 pool: str | None = None,
706 pool_slots: int = DEFAULT_POOL_SLOTS,
707 sla: timedelta | None = None,
708 execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
709 on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
710 on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
711 on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
712 on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
713 pre_execute: TaskPreExecuteHook | None = None,
714 post_execute: TaskPostExecuteHook | None = None,
715 trigger_rule: str = DEFAULT_TRIGGER_RULE,
716 resources: dict[str, Any] | None = None,
717 run_as_user: str | None = None,
718 task_concurrency: int | None = None,
719 max_active_tis_per_dag: int | None = None,
720 executor_config: dict | None = None,
721 do_xcom_push: bool = True,
722 inlets: Any | None = None,
723 outlets: Any | None = None,
724 task_group: TaskGroup | None = None,
725 doc: str | None = None,
726 doc_md: str | None = None,
727 doc_json: str | None = None,
728 doc_yaml: str | None = None,
729 doc_rst: str | None = None,
730 **kwargs,
731 ):
732 from airflow.models.dag import DagContext
733 from airflow.utils.task_group import TaskGroupContext
735 self.__init_kwargs = {}
737 super().__init__()
739 kwargs.pop("_airflow_mapped_validation_only", None)
740 if kwargs:
741 if not conf.getboolean("operators", "ALLOW_ILLEGAL_ARGUMENTS"):
742 raise AirflowException(
743 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
744 f"Invalid arguments were:\n**kwargs: {kwargs}",
745 )
746 warnings.warn(
747 f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
748 "Support for passing such arguments will be dropped in future. "
749 f"Invalid arguments were:\n**kwargs: {kwargs}",
750 category=RemovedInAirflow3Warning,
751 stacklevel=3,
752 )
753 validate_key(task_id)
755 dag = dag or DagContext.get_current_dag()
756 task_group = task_group or TaskGroupContext.get_current_task_group(dag)
758 self.task_id = task_group.child_id(task_id) if task_group else task_id
759 if not self.__from_mapped and task_group:
760 task_group.add(self)
762 self.owner = owner
763 self.email = email
764 self.email_on_retry = email_on_retry
765 self.email_on_failure = email_on_failure
767 if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
768 raise ValueError(
769 f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
770 )
771 self.execution_timeout = execution_timeout
773 self.on_execute_callback = on_execute_callback
774 self.on_failure_callback = on_failure_callback
775 self.on_success_callback = on_success_callback
776 self.on_retry_callback = on_retry_callback
777 self._pre_execute_hook = pre_execute
778 self._post_execute_hook = post_execute
780 if start_date and not isinstance(start_date, datetime):
781 self.log.warning("start_date for %s isn't datetime.datetime", self)
782 elif start_date:
783 self.start_date = timezone.convert_to_utc(start_date)
785 if end_date:
786 self.end_date = timezone.convert_to_utc(end_date)
788 self.executor_config = executor_config or {}
789 self.run_as_user = run_as_user
790 self.retries = parse_retries(retries)
791 self.queue = queue
792 self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
793 self.pool_slots = pool_slots
794 if self.pool_slots < 1:
795 dag_str = f" in dag {dag.dag_id}" if dag else ""
796 raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
797 self.sla = sla
799 if trigger_rule == "dummy":
800 warnings.warn(
801 "dummy Trigger Rule is deprecated. Please use `TriggerRule.ALWAYS`.",
802 RemovedInAirflow3Warning,
803 stacklevel=2,
804 )
805 trigger_rule = TriggerRule.ALWAYS
807 if trigger_rule == "none_failed_or_skipped":
808 warnings.warn(
809 "none_failed_or_skipped Trigger Rule is deprecated. "
810 "Please use `none_failed_min_one_success`.",
811 RemovedInAirflow3Warning,
812 stacklevel=2,
813 )
814 trigger_rule = TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
816 if not TriggerRule.is_valid(trigger_rule):
817 raise AirflowException(
818 f"The trigger_rule must be one of {TriggerRule.all_triggers()},"
819 f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
820 )
822 self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
823 self.depends_on_past: bool = depends_on_past
824 self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
825 self.wait_for_downstream: bool = wait_for_downstream
826 if wait_for_downstream:
827 self.depends_on_past = True
829 self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay")
830 self.retry_exponential_backoff = retry_exponential_backoff
831 self.max_retry_delay = (
832 max_retry_delay
833 if max_retry_delay is None
834 else coerce_timedelta(max_retry_delay, key="max_retry_delay")
835 )
837 # At execution time this becomes a normal dict
838 self.params: ParamsDict | dict = ParamsDict(params)
839 if priority_weight is not None and not isinstance(priority_weight, int):
840 raise AirflowException(
841 f"`priority_weight` for task '{self.task_id}' only accepts integers, "
842 f"received '{type(priority_weight)}'."
843 )
844 self.priority_weight = priority_weight
845 if not WeightRule.is_valid(weight_rule):
846 raise AirflowException(
847 f"The weight_rule must be one of "
848 f"{WeightRule.all_weight_rules},'{dag.dag_id if dag else ''}.{task_id}'; "
849 f"received '{weight_rule}'."
850 )
851 self.weight_rule = weight_rule
852 self.resources = coerce_resources(resources)
853 if task_concurrency and not max_active_tis_per_dag:
854 # TODO: Remove in Airflow 3.0
855 warnings.warn(
856 "The 'task_concurrency' parameter is deprecated. Please use 'max_active_tis_per_dag'.",
857 RemovedInAirflow3Warning,
858 stacklevel=2,
859 )
860 max_active_tis_per_dag = task_concurrency
861 self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
862 self.do_xcom_push = do_xcom_push
864 self.doc_md = doc_md
865 self.doc_json = doc_json
866 self.doc_yaml = doc_yaml
867 self.doc_rst = doc_rst
868 self.doc = doc
870 self.upstream_task_ids: set[str] = set()
871 self.downstream_task_ids: set[str] = set()
873 if dag:
874 self.dag = dag
876 self._log = logging.getLogger("airflow.task.operators")
878 # Lineage
879 self.inlets: list = []
880 self.outlets: list = []
882 if inlets:
883 self.inlets = (
884 inlets
885 if isinstance(inlets, list)
886 else [
887 inlets,
888 ]
889 )
891 if outlets:
892 self.outlets = (
893 outlets
894 if isinstance(outlets, list)
895 else [
896 outlets,
897 ]
898 )
900 if isinstance(self.template_fields, str):
901 warnings.warn(
902 f"The `template_fields` value for {self.task_type} is a string "
903 "but should be a list or tuple of string. Wrapping it in a list for execution. "
904 f"Please update {self.task_type} accordingly.",
905 UserWarning,
906 stacklevel=2,
907 )
908 self.template_fields = [self.template_fields]
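# A minimal custom operator built on the constructor above -- a standalone sketch;
# the class name, template field, and message are made up for illustration.
from typing import Sequence

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class HelloOperator(BaseOperator):
    """Logs and returns a Jinja-templated message."""

    template_fields: Sequence[str] = ("message",)  # rendered before execute()
    ui_color = "#e0f7fa"

    def __init__(self, *, message: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.message = message

    def execute(self, context: Context) -> str:
        # The return value is pushed to XCom because do_xcom_push defaults to True.
        self.log.info("Saying: %s", self.message)
        return self.message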
910 def __eq__(self, other):
911 if type(self) is type(other):
912 # Use getattr() instead of __dict__ as __dict__ doesn't return
913 # correct values for properties.
914 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
915 return False
917 def __ne__(self, other):
918 return not self == other
920 def __hash__(self):
921 hash_components = [type(self)]
922 for component in self._comps:
923 val = getattr(self, component, None)
924 try:
925 hash(val)
926 hash_components.append(val)
927 except TypeError:
928 hash_components.append(repr(val))
929 return hash(tuple(hash_components))
931 # including lineage information
932 def __or__(self, other):
933 """
934 Called for [This Operator] | [Operator]. The inlets of other
935 will be set to pick up the outlets from this operator. Other will
936 be set as a downstream task of this operator.
937 """
938 if isinstance(other, BaseOperator):
939 if not self.outlets and not self.supports_lineage:
940 raise ValueError("No outlets defined for this operator")
941 other.add_inlets([self.task_id])
942 self.set_downstream(other)
943 else:
944 raise TypeError(f"Right hand side ({other}) is not an Operator")
946 return self
948 # /Composing Operators ---------------------------------------------
950 def __gt__(self, other):
951 """
952 Called for [Operator] > [Outlet], so that if other is an attr annotated object
953 it is set as an outlet of this Operator.
954 """
955 if not isinstance(other, Iterable):
956 other = [other]
958 for obj in other:
959 if not attr.has(obj):
960 raise TypeError(f"Left hand side ({obj}) is not an outlet")
961 self.add_outlets(other)
963 return self
965 def __lt__(self, other):
966 """
967 Called for [Inlet] > [Operator] or [Operator] < [Inlet], so that if other is
968 an attr annotated object it is set as an inlet to this operator
969 """
970 if not isinstance(other, Iterable):
971 other = [other]
973 for obj in other:
974 if not attr.has(obj):
975 raise TypeError(f"{obj} cannot be an inlet")
976 self.add_inlets(other)
978 return self
980 def __setattr__(self, key, value):
981 super().__setattr__(key, value)
982 if self.__from_mapped or self._lock_for_execution:
983 return # Skip any custom behavior for validation and during execute.
984 if key in self.__init_kwargs:
985 self.__init_kwargs[key] = value
986 if self.__instantiated and key in self.template_fields:
987 # Resolve upstreams set by assigning an XComArg after initializing
988 # an operator, example:
989 # op = BashOperator()
990 # op.bash_command = "sleep 1"
991 self.set_xcomargs_dependencies()
993 def add_inlets(self, inlets: Iterable[Any]):
994 """Sets inlets to this operator"""
995 self.inlets.extend(inlets)
997 def add_outlets(self, outlets: Iterable[Any]):
998 """Defines the outlets of this operator"""
999 self.outlets.extend(outlets)
1001 def get_inlet_defs(self):
1002 """:meta private:"""
1003 return self.inlets
1005 def get_outlet_defs(self):
1006 """:meta private:"""
1007 return self.outlets
1009 def get_dag(self) -> DAG | None:
1010 return self._dag
1012 @property # type: ignore[override]
1013 def dag(self) -> DAG: # type: ignore[override]
1014 """Returns the Operator's DAG if set, otherwise raises an error"""
1015 if self._dag:
1016 return self._dag
1017 else:
1018 raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
1020 @dag.setter
1021 def dag(self, dag: DAG | None):
1022 """
1023 Operators can be assigned to one DAG, one time. Repeat assignments to
1024 that same DAG are ok.
1025 """
1026 from airflow.models.dag import DAG
1028 if dag is None:
1029 self._dag = None
1030 return
1031 if not isinstance(dag, DAG):
1032 raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
1033 elif self.has_dag() and self.dag is not dag:
1034 raise AirflowException(f"The DAG assigned to {self} can not be changed.")
1036 if self.__from_mapped:
1037 pass # Don't add to DAG -- the mapped task takes the place.
1038 elif self.task_id not in dag.task_dict:
1039 dag.add_task(self)
1040 elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
1041 dag.add_task(self)
1043 self._dag = dag
1045 def has_dag(self):
1046 """Returns True if the Operator has been assigned to a DAG."""
1047 return self._dag is not None
1049 deps: frozenset[BaseTIDep] = frozenset(
1050 {
1051 NotInRetryPeriodDep(),
1052 PrevDagrunDep(),
1053 TriggerRuleDep(),
1054 NotPreviouslySkippedDep(),
1055 }
1056 )
1057 """
1058 Returns the set of dependencies for the operator. These differ from execution
1059 context dependencies in that they are specific to tasks and can be
1060 extended/overridden by subclasses.
1061 """
1063 def prepare_for_execution(self) -> BaseOperator:
1064 """
1065 Lock the task for execution to disable custom actions in __setattr__ and
1066 return a copy of the task.
1067 """
1068 other = copy.copy(self)
1069 other._lock_for_execution = True
1070 return other
1072 def set_xcomargs_dependencies(self) -> None:
1073 """
1074 Resolves upstream dependencies of a task. In this way passing an ``XComArg``
1075 as the value for a template field will result in creating an upstream relation between
1076 two tasks.
1078 **Example**: ::
1080 with DAG(...):
1081 generate_content = GenerateContentOperator(task_id="generate_content")
1082 send_email = EmailOperator(..., html_content=generate_content.output)
1084 # This is equivalent to
1085 with DAG(...):
1086 generate_content = GenerateContentOperator(task_id="generate_content")
1087 send_email = EmailOperator(
1088 ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
1089 )
1090 generate_content >> send_email
1092 """
1093 from airflow.models.xcom_arg import XComArg
1095 for field in self.template_fields:
1096 if hasattr(self, field):
1097 arg = getattr(self, field)
1098 XComArg.apply_upstream_relationship(self, arg)
1100 @prepare_lineage
1101 def pre_execute(self, context: Any):
1102 """This hook is triggered right before self.execute() is called."""
1103 if self._pre_execute_hook is not None:
1104 self._pre_execute_hook(context)
1106 def execute(self, context: Context) -> Any:
1107 """
1108 This is the main method to derive when creating an operator.
1109 Context is the same dictionary used when rendering jinja templates.
1111 Refer to get_template_context for more context.
1112 """
1113 raise NotImplementedError()
1115 @apply_lineage
1116 def post_execute(self, context: Any, result: Any = None):
1117 """
1118 This hook is triggered right after self.execute() is called.
1119 It is passed the execution context and any results returned by the
1120 operator.
1121 """
1122 if self._post_execute_hook is not None:
1123 self._post_execute_hook(context, result)
1125 def on_kill(self) -> None:
1126 """
1127 Override this method to clean up subprocesses when a task instance
1128 gets killed. Any use of the threading, subprocess or multiprocessing
1129 module within an operator needs to be cleaned up or it will leave
1130 ghost processes behind.
1131 """
1133 def __deepcopy__(self, memo):
1134 # Hack sorting double chained task lists by task_id to avoid hitting
1135 # max_depth on deepcopy operations.
1136 sys.setrecursionlimit(5000) # TODO fix this in a better way
1138 cls = self.__class__
1139 result = cls.__new__(cls)
1140 memo[id(self)] = result
1142 shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
1144 for k, v in self.__dict__.items():
1145 if k == "_BaseOperator__instantiated":
1146 # Don't set this until the _end_, as it changes behaviour of __setattr__
1147 continue
1148 if k not in shallow_copy:
1149 setattr(result, k, copy.deepcopy(v, memo))
1150 else:
1151 setattr(result, k, copy.copy(v))
1152 result.__instantiated = self.__instantiated
1153 return result
1155 def __getstate__(self):
1156 state = dict(self.__dict__)
1157 del state["_log"]
1159 return state
1161 def __setstate__(self, state):
1162 self.__dict__ = state
1163 self._log = logging.getLogger("airflow.task.operators")
1165 def render_template_fields(
1166 self,
1167 context: Context,
1168 jinja_env: jinja2.Environment | None = None,
1169 ) -> None:
1170 """Template all attributes listed in *self.template_fields*.
1172 This mutates the attributes in-place and is irreversible.
1174 :param context: Context dict with values to apply on content.
1175 :param jinja_env: Jinja environment to use for rendering.
1176 """
1177 if not jinja_env:
1178 jinja_env = self.get_template_env()
1179 self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
1181 @provide_session
1182 def clear(
1183 self,
1184 start_date: datetime | None = None,
1185 end_date: datetime | None = None,
1186 upstream: bool = False,
1187 downstream: bool = False,
1188 session: Session = NEW_SESSION,
1189 ):
1190 """
1191 Clears the state of task instances associated with the task, following
1192 the parameters specified.
1193 """
1194 qry = session.query(TaskInstance).filter(TaskInstance.dag_id == self.dag_id)
1196 if start_date:
1197 qry = qry.filter(TaskInstance.execution_date >= start_date)
1198 if end_date:
1199 qry = qry.filter(TaskInstance.execution_date <= end_date)
1201 tasks = [self.task_id]
1203 if upstream:
1204 tasks += [t.task_id for t in self.get_flat_relatives(upstream=True)]
1206 if downstream:
1207 tasks += [t.task_id for t in self.get_flat_relatives(upstream=False)]
1209 qry = qry.filter(TaskInstance.task_id.in_(tasks))
1210 results = qry.all()
1211 count = len(results)
1212 clear_task_instances(results, session, dag=self.dag)
1213 session.commit()
1214 return count
1216 @provide_session
1217 def get_task_instances(
1218 self,
1219 start_date: datetime | None = None,
1220 end_date: datetime | None = None,
1221 session: Session = NEW_SESSION,
1222 ) -> list[TaskInstance]:
1223 """Get task instances related to this task for a specific date range."""
1224 from airflow.models import DagRun
1226 end_date = end_date or timezone.utcnow()
1227 return (
1228 session.query(TaskInstance)
1229 .join(TaskInstance.dag_run)
1230 .filter(TaskInstance.dag_id == self.dag_id)
1231 .filter(TaskInstance.task_id == self.task_id)
1232 .filter(DagRun.execution_date >= start_date)
1233 .filter(DagRun.execution_date <= end_date)
1234 .order_by(DagRun.execution_date)
1235 .all()
1236 )
1238 @provide_session
1239 def run(
1240 self,
1241 start_date: datetime | None = None,
1242 end_date: datetime | None = None,
1243 ignore_first_depends_on_past: bool = True,
1244 ignore_ti_state: bool = False,
1245 mark_success: bool = False,
1246 test_mode: bool = False,
1247 session: Session = NEW_SESSION,
1248 ) -> None:
1249 """Run a set of task instances for a date range."""
1250 from airflow.models import DagRun
1251 from airflow.utils.types import DagRunType
1253 # Assertions for typing -- we need a dag, for this function, and when we have a DAG we are
1254 # _guaranteed_ to have start_date (else we couldn't have been added to a DAG)
1255 if TYPE_CHECKING:
1256 assert self.start_date
1258 start_date = pendulum.instance(start_date or self.start_date)
1259 end_date = pendulum.instance(end_date or self.end_date or timezone.utcnow())
1261 for info in self.dag.iter_dagrun_infos_between(start_date, end_date, align=False):
1262 ignore_depends_on_past = info.logical_date == start_date and ignore_first_depends_on_past
1263 try:
1264 dag_run = (
1265 session.query(DagRun)
1266 .filter(
1267 DagRun.dag_id == self.dag_id,
1268 DagRun.execution_date == info.logical_date,
1269 )
1270 .one()
1271 )
1272 ti = TaskInstance(self, run_id=dag_run.run_id)
1273 except NoResultFound:
1274 # This is _mostly_ only used in tests
1275 dr = DagRun(
1276 dag_id=self.dag_id,
1277 run_id=DagRun.generate_run_id(DagRunType.MANUAL, info.logical_date),
1278 run_type=DagRunType.MANUAL,
1279 execution_date=info.logical_date,
1280 data_interval=info.data_interval,
1281 )
1282 ti = TaskInstance(self, run_id=dr.run_id)
1283 ti.dag_run = dr
1284 session.add(dr)
1285 session.flush()
1287 ti.run(
1288 mark_success=mark_success,
1289 ignore_depends_on_past=ignore_depends_on_past,
1290 ignore_ti_state=ignore_ti_state,
1291 test_mode=test_mode,
1292 session=session,
1293 )
1295 def dry_run(self) -> None:
1296 """Performs a dry run of the operator - just renders template fields."""
1297 self.log.info("Dry run")
1298 for field in self.template_fields:
1299 try:
1300 content = getattr(self, field)
1301 except AttributeError:
1302 raise AttributeError(
1303 f"{field!r} is configured as a template field "
1304 f"but {self.task_type} does not have this attribute."
1305 )
1307 if content and isinstance(content, str):
1308 self.log.info("Rendering template for %s", field)
1309 self.log.info(content)
1311 def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
1312 """
1313 Get list of the direct relatives to the current task, upstream or
1314 downstream.
1315 """
1316 if upstream:
1317 return self.upstream_list
1318 else:
1319 return self.downstream_list
1321 def __repr__(self):
1322 return "<Task({self.task_type}): {self.task_id}>".format(self=self)
1324 @property
1325 def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
1326 return self.__class__
1328 @property
1329 def task_type(self) -> str:
1330 """@property: type of the task"""
1331 return self.__class__.__name__
1333 @property
1334 def operator_name(self) -> str:
1335 """@property: use a more friendly display name for the operator, if set"""
1336 try:
1337 return self.custom_operator_name # type: ignore
1338 except AttributeError:
1339 return self.task_type
1341 @property
1342 def roots(self) -> list[BaseOperator]:
1343 """Required by DAGNode."""
1344 return [self]
1346 @property
1347 def leaves(self) -> list[BaseOperator]:
1348 """Required by DAGNode."""
1349 return [self]
1351 @property
1352 def output(self) -> XComArg:
1353 """Returns reference to XCom pushed by current operator"""
1354 from airflow.models.xcom_arg import XComArg
1356 return XComArg(operator=self)
1358 @staticmethod
1359 def xcom_push(
1360 context: Any,
1361 key: str,
1362 value: Any,
1363 execution_date: datetime | None = None,
1364 ) -> None:
1365 """
1366 Make an XCom available for tasks to pull.
1368 :param context: Execution Context Dictionary
1369 :param key: A key for the XCom
1370 :param value: A value for the XCom. The value is pickled and stored
1371 in the database.
1372 :param execution_date: if provided, the XCom will not be visible until
1373 this date. This can be used, for example, to send a message to a
1374 task on a future date without it being immediately visible.
1375 """
1376 context["ti"].xcom_push(key=key, value=value, execution_date=execution_date)
1378 @staticmethod
1379 def xcom_pull(
1380 context: Any,
1381 task_ids: str | list[str] | None = None,
1382 dag_id: str | None = None,
1383 key: str = XCOM_RETURN_KEY,
1384 include_prior_dates: bool | None = None,
1385 ) -> Any:
1386 """
1387 Pull XComs that optionally meet certain criteria.
1389 The default value for `key` limits the search to XComs
1390 that were returned by other tasks (as opposed to those that were pushed
1391 manually). To remove this filter, pass key=None (or any desired value).
1393 If a single task_id string is provided, the result is the value of the
1394 most recent matching XCom from that task_id. If multiple task_ids are
1395 provided, a tuple of matching values is returned. None is returned
1396 whenever no matches are found.
1398 :param context: Execution Context Dictionary
1399 :param key: A key for the XCom. If provided, only XComs with matching
1400 keys will be returned. The default key is 'return_value', also
1401 available as a constant XCOM_RETURN_KEY. This key is automatically
1402 given to XComs returned by tasks (as opposed to being pushed
1403 manually). To remove the filter, pass key=None.
1404 :param task_ids: Only XComs from tasks with matching ids will be
1405 pulled. Can pass None to remove the filter.
1406 :param dag_id: If provided, only pulls XComs from this DAG.
1407 If None (default), the DAG of the calling task is used.
1408 :param include_prior_dates: If False, only XComs from the current
1409 execution_date are returned. If True, XComs from previous dates
1410 are returned as well.
1411 """
1412 return context["ti"].xcom_pull(
1413 key=key, task_ids=task_ids, dag_id=dag_id, include_prior_dates=include_prior_dates
1414 )
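# Standalone sketch of the xcom_push/xcom_pull helpers above, used from inside an
# operator's execute(); the operator name, upstream task id, and key are hypothetical.
from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class ReportOperator(BaseOperator):
    def execute(self, context: Context) -> None:
        # Push an explicit value in addition to the auto-pushed return value.
        self.xcom_push(context, key="row_count", value=42)
        # Pull what the upstream "extract" task returned (key defaults to XCOM_RETURN_KEY).
        upstream_result = self.xcom_pull(context, task_ids="extract")
        self.log.info("Upstream 'extract' returned %s", upstream_result)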
1416 @classmethod
1417 def get_serialized_fields(cls):
1418 """Stringified DAGs and operators contain exactly these fields."""
1419 if not cls.__serialized_fields:
1420 from airflow.models.dag import DagContext
1422 # make sure the following dummy task is not added to current active
1423 # dag in context, otherwise, it will result in
1424 # `RuntimeError: dictionary changed size during iteration`
1425 # Exception in SerializedDAG.serialize_dag() call.
1426 DagContext.push_context_managed_dag(None)
1427 cls.__serialized_fields = frozenset(
1428 vars(BaseOperator(task_id="test")).keys()
1429 - {
1430 "upstream_task_ids",
1431 "default_args",
1432 "dag",
1433 "_dag",
1434 "label",
1435 "_BaseOperator__instantiated",
1436 "_BaseOperator__init_kwargs",
1437 "_BaseOperator__from_mapped",
1438 }
1439 | { # Class level defaults need to be added to this list
1440 "start_date",
1441 "end_date",
1442 "_task_type",
1443 "_operator_name",
1444 "subdag",
1445 "ui_color",
1446 "ui_fgcolor",
1447 "template_ext",
1448 "template_fields",
1449 "template_fields_renderers",
1450 "params",
1451 }
1452 )
1453 DagContext.pop_context_managed_dag()
1455 return cls.__serialized_fields
1457 def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
1458 """Required by DAGNode."""
1459 return DagAttributeTypes.OP, self.task_id
1461 @property
1462 def inherits_from_empty_operator(self):
1463 """Used to determine if an Operator is inherited from EmptyOperator"""
1464 # This looks like `isinstance(self, EmptyOperator)` would work, but this also
1465 # needs to cope when `self` is a Serialized instance of a EmptyOperator or one
1466 # of its sub-classes (which don't inherit from anything but BaseOperator).
1467 return getattr(self, "_is_empty", False)
1469 def defer(
1470 self,
1471 *,
1472 trigger: BaseTrigger,
1473 method_name: str,
1474 kwargs: dict[str, Any] | None = None,
1475 timeout: timedelta | None = None,
1476 ):
1477 """
1478 Marks this Operator as being "deferred" - that is, suspending its
1479 execution until the provided trigger fires an event.
1481 This is achieved by raising a special exception (TaskDeferred)
1482 which is caught in the main _execute_task wrapper.
1483 """
1484 raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
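# Standalone sketch of a deferrable operator built on defer() above, assuming the
# standard TimeDeltaTrigger; the operator name and delay are illustrative.
from datetime import timedelta
from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.utils.context import Context


class WaitOneHourOperator(BaseOperator):
    def execute(self, context: Context) -> None:
        # Free the worker slot; the triggerer resumes the task in execute_complete().
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(hours=1)),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: Any = None) -> None:
        self.log.info("One hour has passed; resuming.")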
1486 def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
1487 """:meta private:"""
1488 return self
1491# TODO: Deprecate for Airflow 3.0
1492Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
1495def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
1496 r"""
1497 Given a number of tasks, builds a dependency chain.
1499 This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
1500 or lists containing any mix of these types (or a mix in the same list). If you want to chain between two
1501 lists you must ensure they have the same length.
1503 Using classic operators/sensors:
1505 .. code-block:: python
1507 chain(t1, [t2, t3], [t4, t5], t6)
1509 is equivalent to::
1511 / -> t2 -> t4 \
1512 t1 -> t6
1513 \ -> t3 -> t5 /
1515 .. code-block:: python
1517 t1.set_downstream(t2)
1518 t1.set_downstream(t3)
1519 t2.set_downstream(t4)
1520 t3.set_downstream(t5)
1521 t4.set_downstream(t6)
1522 t5.set_downstream(t6)
1524 Using task-decorated functions aka XComArgs:
1526 .. code-block:: python
1528 chain(x1(), [x2(), x3()], [x4(), x5()], x6())
1530 is equivalent to::
1532 / -> x2 -> x4 \
1533 x1 -> x6
1534 \ -> x3 -> x5 /
1536 .. code-block:: python
1538 x1 = x1()
1539 x2 = x2()
1540 x3 = x3()
1541 x4 = x4()
1542 x5 = x5()
1543 x6 = x6()
1544 x1.set_downstream(x2)
1545 x1.set_downstream(x3)
1546 x2.set_downstream(x4)
1547 x3.set_downstream(x5)
1548 x4.set_downstream(x6)
1549 x5.set_downstream(x6)
1551 Using TaskGroups:
1553 .. code-block:: python
1555 chain(t1, task_group1, task_group2, t2)
1557 t1.set_downstream(task_group1)
1558 task_group1.set_downstream(task_group2)
1559 task_group2.set_downstream(t2)
1562 It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:
1564 .. code-block:: python
1566 chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, t2())
1568 is equivalent to::
1570 / "branch one" -> x1 \
1571 t1 -> t2 -> x3
1572 \ "branch two" -> x2 /
1574 .. code-block:: python
1576 x1 = x1()
1577 x2 = x2()
1578 x3 = x3()
1579 label1 = Label("branch one")
1580 label2 = Label("branch two")
1581 t1.set_downstream(label1)
1582 label1.set_downstream(x1)
1583 t2.set_downstream(label2)
1584 label2.set_downstream(x2)
1585 x1.set_downstream(task_group1)
1586 x2.set_downstream(task_group1)
1587 task_group1.set_downstream(x3)
1589 # or
1591 x1 = x1()
1592 x2 = x2()
1593 x3 = x3()
1594 t1.set_downstream(x1, edge_modifier=Label("branch one"))
1595 t1.set_downstream(x2, edge_modifier=Label("branch two"))
1596 x1.set_downstream(task_group1)
1597 x2.set_downstream(task_group1)
1598 task_group1.set_downstream(x3)
1601 :param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
1602 """
1603 for index, up_task in enumerate(tasks[:-1]):
1604 down_task = tasks[index + 1]
1605 if isinstance(up_task, DependencyMixin):
1606 up_task.set_downstream(down_task)
1607 continue
1608 if isinstance(down_task, DependencyMixin):
1609 down_task.set_upstream(up_task)
1610 continue
1611 if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
1612 raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
1613 up_task_list = up_task
1614 down_task_list = down_task
1615 if len(up_task_list) != len(down_task_list):
1616 raise AirflowException(
1617 f"Chain not supported for different length Iterable. "
1618 f"Got {len(up_task_list)} and {len(down_task_list)}."
1619 )
1620 for up_t, down_t in zip(up_task_list, down_task_list):
1621 up_t.set_downstream(down_t)
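# A concrete use of chain() above inside a DAG -- a sketch assuming the standard
# EmptyOperator; the dag id and task ids are illustrative.
import pendulum

from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

with DAG(dag_id="chain_example", start_date=pendulum.datetime(2022, 1, 1), schedule=None):
    t1, t2, t3, t4, t5, t6 = (EmptyOperator(task_id=f"t{i}") for i in range(1, 7))
    # t1 fans out to [t2, t3], which pair element-wise with [t4, t5], then fan in to t6.
    chain(t1, [t2, t3], [t4, t5], t6)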
1624def cross_downstream(
1625 from_tasks: Sequence[DependencyMixin],
1626 to_tasks: DependencyMixin | Sequence[DependencyMixin],
1627):
1628 r"""
1629 Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
1631 Using classic operators/sensors:
1633 .. code-block:: python
1635 cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
1637 is equivalent to::
1639 t1 ---> t4
1640 \ /
1641 t2 -X -> t5
1642 / \
1643 t3 ---> t6
1645 .. code-block:: python
1647 t1.set_downstream(t4)
1648 t1.set_downstream(t5)
1649 t1.set_downstream(t6)
1650 t2.set_downstream(t4)
1651 t2.set_downstream(t5)
1652 t2.set_downstream(t6)
1653 t3.set_downstream(t4)
1654 t3.set_downstream(t5)
1655 t3.set_downstream(t6)
1657 Using task-decorated functions aka XComArgs:
1659 .. code-block:: python
1661 cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
1663 is equivalent to::
1665 x1 ---> x4
1666 \ /
1667 x2 -X -> x5
1668 / \
1669 x3 ---> x6
1671 .. code-block:: python
1673 x1 = x1()
1674 x2 = x2()
1675 x3 = x3()
1676 x4 = x4()
1677 x5 = x5()
1678 x6 = x6()
1679 x1.set_downstream(x4)
1680 x1.set_downstream(x5)
1681 x1.set_downstream(x6)
1682 x2.set_downstream(x4)
1683 x2.set_downstream(x5)
1684 x2.set_downstream(x6)
1685 x3.set_downstream(x4)
1686 x3.set_downstream(x5)
1687 x3.set_downstream(x6)
1689 It is also possible to mix between classic operator/sensor and XComArg tasks:
1691 .. code-block:: python
1693 cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1695 is equivalent to::
1697 t1 ---> x1
1698 \ /
1699 x2 -X -> t2
1700 / \
1701 t3 ---> x3
1703 .. code-block:: python
1705 x1 = x1()
1706 x2 = x2()
1707 x3 = x3()
1708 t1.set_downstream(x1)
1709 t1.set_downstream(t2)
1710 t1.set_downstream(x3)
1711 x2.set_downstream(x1)
1712 x2.set_downstream(t2)
1713 x2.set_downstream(x3)
1714 t3.set_downstream(x1)
1715 t3.set_downstream(t2)
1716 t3.set_downstream(x3)
1718 :param from_tasks: List of tasks or XComArgs to start from.
1719 :param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1720 """
1721 for task in from_tasks:
1722 task.set_downstream(to_tasks)
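# A concrete use of cross_downstream() above -- every task in from_tasks becomes an
# upstream of every task in to_tasks. A sketch assuming the standard EmptyOperator;
# the dag id and task ids are illustrative.
import pendulum

from airflow.models.baseoperator import cross_downstream
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

with DAG(dag_id="cross_example", start_date=pendulum.datetime(2022, 1, 1), schedule=None):
    extracts = [EmptyOperator(task_id=f"extract_{i}") for i in range(3)]
    loads = [EmptyOperator(task_id=f"load_{i}") for i in range(3)]
    cross_downstream(from_tasks=extracts, to_tasks=loads)  # 3 x 3 = 9 dependencies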
1725# pyupgrade assumes all type annotations can be lazily evaluated, but this is
1726# not the case for attrs-decorated classes, since cattrs needs to evaluate the
1727# annotation expressions at runtime, and Python before 3.9.0 does not lazily
1728# evaluate those. Putting the expression in a top-level assignment statement
1729# communicates this runtime requirement to pyupgrade.
1730BaseOperatorClassList = List[Type[BaseOperator]]
1733@attr.s(auto_attribs=True)
1734class BaseOperatorLink(metaclass=ABCMeta):
1735 """Abstract base class that defines how we get an operator link."""
1737 operators: ClassVar[BaseOperatorClassList] = []
1738 """
1739 This property will be used by Airflow Plugins to find the Operators to which you want
1740 to assign this Operator Link.
1742 :return: List of Operator classes for whose tasks you want to create the extra link
1743 """
1745 @property
1746 @abstractmethod
1747 def name(self) -> str:
1748 """Name of the link. This will be the button name on the task UI."""
1750 @abstractmethod
1751 def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
1752 """Link to external system.
1754 Note: The old signature of this function was ``(self, operator, dttm: datetime)``. That is still
1755 supported at runtime but is deprecated.
1757 :param operator: The Airflow operator object this link is associated to.
1758 :param ti_key: TaskInstance ID to return link for.
1759 :return: link to external system
1760 """
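# Sketch of a concrete operator link built on the abstract class above, following the
# usual pattern (registered via an Airflow plugin in practice); the link name, target
# operator, and URL are illustrative.
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.taskinstance import TaskInstanceKey


class AirflowDocsLink(BaseOperatorLink):
    """Adds an "Airflow Docs" button to the task instance details view."""

    name = "Airflow Docs"

    def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
        return "https://airflow.apache.org/docs/"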