Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/lineage/__init__.py: 44%

64 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Provides lineage support functions.""" 

19from __future__ import annotations 

20 

21import itertools 

22import logging 

23from functools import wraps 

24from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast 

25 

26from airflow.configuration import conf 

27from airflow.lineage.backend import LineageBackend 

28 

29if TYPE_CHECKING: 

30 from airflow.utils.context import Context 

31 

32 

# XCom keys under which a task's lineage outlets and inlets are pushed and pulled.
PIPELINE_OUTLETS = "pipeline_outlets"
PIPELINE_INLETS = "pipeline_inlets"

# Inlet marker asking for the outlets of the direct upstream tasks to be picked up.
AUTO = "auto"

# Module-level logger.
log = logging.getLogger(__name__)

38 

39 

def get_backend() -> LineageBackend | None:
    """
    Return an instance of the configured lineage backend, if any.

    The backend class is read from the ``backend`` option of the ``lineage``
    config section. ``None`` is returned when no class is configured.

    :raises TypeError: if the configured class does not derive from ``LineageBackend``.
    """
    backend_cls = conf.getimport("lineage", "backend", fallback=None)

    # Nothing configured: lineage is simply not sent anywhere.
    if not backend_cls:
        return None

    if issubclass(backend_cls, LineageBackend):
        return backend_cls()

    raise TypeError(
        f"Your custom Lineage class `{backend_cls.__name__}` "
        f"is not a subclass of `{LineageBackend.__name__}`."
    )

54 

55 

def _render_object(obj: Any, context: Context) -> dict:
    """Render ``obj`` through ``render_template`` of the task attached to ``context["ti"]``."""
    task = context["ti"].task
    return task.render_template(obj, context)

58 

59 

# Type variable used by the decorators in this module so that the decorated
# callable keeps its own type in their signatures.
T = TypeVar("T", bound=Callable)

61 

62 

def apply_lineage(func: T) -> T:
    """
    Publish the operator's lineage after the wrapped callable has run.

    Non-empty outlets and inlets are saved to XCom, and the lineage is sent to
    the lineage backend when one is configured. The backend is resolved once,
    at the time the decorator is applied.
    """
    lineage_backend = get_backend()

    @wraps(func)
    def wrapper(self, context, *args, **kwargs):
        self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)

        result = func(self, context, *args, **kwargs)

        # Snapshot both collections (outlets first) before anything is pushed.
        lineage = (
            (PIPELINE_OUTLETS, list(self.outlets)),
            (PIPELINE_INLETS, list(self.inlets)),
        )
        for xcom_key, entries in lineage:
            # Empty collections are not written to XCom.
            if entries:
                self.xcom_push(
                    context, key=xcom_key, value=entries, execution_date=context["ti"].execution_date
                )

        if lineage_backend:
            lineage_backend.send_lineage(
                operator=self, inlets=self.inlets, outlets=self.outlets, context=context
            )

        return result

    return cast(T, wrapper)

98 

99 

def prepare_lineage(func: T) -> T:
    """
    Resolve and render the operator's inlets and outlets before running ``func``.

    Inlets can be:

    * "auto" -> picks up any outlets from direct upstream tasks that have outlets defined, as such that
      if A -> B -> C and B does not have outlets but A does, these are provided as inlets.
    * "list of task_ids" -> picks up outlets from the upstream task_ids
    * "list of datasets" -> manually defined list of data

    A single string or operator given as inlets is wrapped in a list, and outlets
    that are not a list are wrapped in a list as well.
    """

    @wraps(func)
    def wrapper(self, context, *args, **kwargs):
        from airflow.models.abstractoperator import AbstractOperator

        self.log.debug("Preparing lineage inlets and outlets")

        if isinstance(self.inlets, (str, AbstractOperator)):
            self.inlets = [self.inlets]

        if self.inlets:
            if not isinstance(self.inlets, list):
                raise AttributeError("inlets is not a list, operator, string or attr annotated object")

            # Task ids named directly or via operator instances; only those that
            # are actually upstream of this task are kept.
            named_ids = {entry for entry in self.inlets if isinstance(entry, str)}
            operator_ids = (entry.task_id for entry in self.inlets if isinstance(entry, AbstractOperator))
            task_ids = named_ids.union(operator_ids).intersection(self.get_flat_relative_ids(upstream=True))

            # With the auto marker present, add every direct upstream task id.
            if any(marker in self.inlets for marker in (AUTO.upper(), AUTO.lower())):
                task_ids = task_ids.union(self.upstream_task_ids)

            # Drop the string entries (the auto marker and task_ids); non-string entries are kept.
            self.inlets = [entry for entry in self.inlets if not isinstance(entry, str)]

            pulled = self.xcom_pull(context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS)

            # xcom_pull returns a list of items for each given task_id, so flatten one level.
            upstream_outlets = list(itertools.chain.from_iterable(pulled))
            self.inlets.extend(upstream_outlets)

        if not isinstance(self.outlets, list):
            self.outlets = [self.outlets]

        # Render inlets first, then outlets.
        self.inlets = [_render_object(entry, context) for entry in self.inlets]
        self.outlets = [_render_object(entry, context) for entry in self.outlets]

        self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)

        return func(self, context, *args, **kwargs)

    return cast(T, wrapper)