1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Provides lineage support functions."""
19
20from __future__ import annotations
21
22import logging
23from functools import wraps
24from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
25
26from airflow.configuration import conf
27from airflow.lineage.backend import LineageBackend
28from airflow.utils.session import create_session
29
30if TYPE_CHECKING:
31 from airflow.utils.context import Context
32
# XCom key under which ``apply_lineage`` pushes a task's outlets and from which
# ``prepare_lineage`` pulls the outlets of upstream tasks.
PIPELINE_OUTLETS = "pipeline_outlets"
# XCom key under which ``apply_lineage`` pushes a task's inlets.
PIPELINE_INLETS = "pipeline_inlets"
# Marker inlet value: ``prepare_lineage`` checks for its upper- and lower-case
# forms and, when found, adds ``upstream_task_ids`` to the tasks it pulls from.
AUTO = "auto"

# Module-level logger named after this module.
log = logging.getLogger(__name__)
38
39
def get_backend() -> LineageBackend | None:
    """
    Instantiate the lineage backend named by the ``[lineage] backend`` config option.

    :return: a new instance of the configured class, or ``None`` when the
        option resolves to nothing.
    :raises TypeError: if the configured class is not a subclass of
        ``LineageBackend``.
    """
    backend_cls = conf.getimport("lineage", "backend", fallback=None)

    # No backend configured: there is nothing to instantiate.
    if not backend_cls:
        return None

    if issubclass(backend_cls, LineageBackend):
        return backend_cls()

    raise TypeError(
        f"Your custom Lineage class `{backend_cls.__name__}` "
        f"is not a subclass of `{LineageBackend.__name__}`."
    )
54
55
def _render_object(obj: Any, context: Context) -> dict:
    """
    Render ``obj`` with the templating of the task attached to the current task instance.

    :param obj: the object handed to ``render_template``.
    :param context: the execution context; its ``"ti"`` entry supplies the task.
    :return: whatever the task's ``render_template`` returns for ``obj``.
    """
    task_instance = context["ti"]
    if TYPE_CHECKING:
        # Narrowing for static type checkers only; never executed at runtime.
        assert task_instance.task
    owning_task = task_instance.task
    return owning_task.render_template(obj, context)
61
62
# Type variable bound to ``Callable``; the decorators in this module are annotated
# ``(func: T) -> T`` so the decorated callable keeps its declared type.
T = TypeVar("T", bound=Callable)
64
65
def apply_lineage(func: T) -> T:
    """
    Decorate an operator method so its lineage is published after it runs.

    Once the wrapped method has returned, non-empty outlets and inlets are
    pushed to XCom, and if a lineage backend is configured they are sent to
    it as well. The backend is looked up once, when the decorator is applied.

    :param func: method taking ``(self, context, *args, **kwargs)``.
    :return: the wrapping function, typed as the original callable.
    """
    lineage_backend = get_backend()

    @wraps(func)
    def wrapper(self, context, *args, **kwargs):
        self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)

        result = func(self, context, *args, **kwargs)

        # Snapshot both collections as plain lists before publishing them.
        outlet_list = list(self.outlets)
        inlet_list = list(self.inlets)

        # Outlets first, then inlets; empty collections are not pushed at all.
        for xcom_key, payload in ((PIPELINE_OUTLETS, outlet_list), (PIPELINE_INLETS, inlet_list)):
            if payload:
                self.xcom_push(context, key=xcom_key, value=payload)

        if lineage_backend:
            lineage_backend.send_lineage(
                operator=self, inlets=self.inlets, outlets=self.outlets, context=context
            )

        return result

    return cast(T, wrapper)
96
97
98def prepare_lineage(func: T) -> T:
99 """
100 Prepare the lineage inlets and outlets.
101
102 Inlets can be:
103
104 * "auto" -> picks up any outlets from direct upstream tasks that have outlets defined, as such that
105 if A -> B -> C and B does not have outlets but A does, these are provided as inlets.
106 * "list of task_ids" -> picks up outlets from the upstream task_ids
107 * "list of datasets" -> manually defined list of data
108
109 """
110
111 @wraps(func)
112 def wrapper(self, context, *args, **kwargs):
113 from airflow.models.abstractoperator import AbstractOperator
114
115 self.log.debug("Preparing lineage inlets and outlets")
116
117 if isinstance(self.inlets, (str, AbstractOperator)):
118 self.inlets = [self.inlets]
119
120 if self.inlets and isinstance(self.inlets, list):
121 # get task_ids that are specified as parameter and make sure they are upstream
122 task_ids = {o for o in self.inlets if isinstance(o, str)}.union(
123 op.task_id for op in self.inlets if isinstance(op, AbstractOperator)
124 ).intersection(self.get_flat_relative_ids(upstream=True))
125
126 # pick up unique direct upstream task_ids if AUTO is specified
127 if AUTO.upper() in self.inlets or AUTO.lower() in self.inlets:
128 task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids))
129
130 # Remove auto and task_ids
131 self.inlets = [i for i in self.inlets if not isinstance(i, str)]
132
133 # We manually create a session here since xcom_pull returns a
134 # LazySelectSequence proxy. If we do not pass a session, a new one
135 # will be created, but that session will not be properly closed.
136 # After we are done iterating, we can safely close this session.
137 with create_session() as session:
138 _inlets = self.xcom_pull(
139 context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
140 )
141 self.inlets.extend(i for it in _inlets for i in it)
142
143 elif self.inlets:
144 raise AttributeError("inlets is not a list, operator, string or attr annotated object")
145
146 if not isinstance(self.outlets, list):
147 self.outlets = [self.outlets]
148
149 # render inlets and outlets
150 self.inlets = [_render_object(i, context) for i in self.inlets]
151
152 self.outlets = [_render_object(i, context) for i in self.outlets]
153
154 self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
155
156 return func(self, context, *args, **kwargs)
157
158 return cast(T, wrapper)