Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/operator_helpers.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

99 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import logging 

21from datetime import datetime 

22from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, TypeVar 

23 

24from airflow import settings 

25from airflow.utils.context import Context, lazy_mapping_from_context 

26 

27if TYPE_CHECKING: 

28 from airflow.utils.context import OutletEventAccessors 

29 

R = TypeVar("R")

# Prefix for context variable names in the dotted (``airflow.ctx.dag_id``) style.
DEFAULT_FORMAT_PREFIX = "airflow.ctx."
# Prefix for context variable names in the environment-variable (``AIRFLOW_CTX_DAG_ID``) style.
ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"

# For each built-in context variable, its name under both naming styles:
# ``default`` (dotted) and ``env_var_format`` (upper-case, underscore separated).
AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
    "AIRFLOW_CONTEXT_DAG_ID": dict(
        default=DEFAULT_FORMAT_PREFIX + "dag_id",
        env_var_format=ENV_VAR_FORMAT_PREFIX + "DAG_ID",
    ),
    "AIRFLOW_CONTEXT_TASK_ID": dict(
        default=DEFAULT_FORMAT_PREFIX + "task_id",
        env_var_format=ENV_VAR_FORMAT_PREFIX + "TASK_ID",
    ),
    "AIRFLOW_CONTEXT_EXECUTION_DATE": dict(
        default=DEFAULT_FORMAT_PREFIX + "execution_date",
        env_var_format=ENV_VAR_FORMAT_PREFIX + "EXECUTION_DATE",
    ),
    "AIRFLOW_CONTEXT_TRY_NUMBER": dict(
        default=DEFAULT_FORMAT_PREFIX + "try_number",
        env_var_format=ENV_VAR_FORMAT_PREFIX + "TRY_NUMBER",
    ),
    "AIRFLOW_CONTEXT_DAG_RUN_ID": dict(
        default=DEFAULT_FORMAT_PREFIX + "dag_run_id",
        env_var_format=ENV_VAR_FORMAT_PREFIX + "DAG_RUN_ID",
    ),
    "AIRFLOW_CONTEXT_DAG_OWNER": dict(
        default=DEFAULT_FORMAT_PREFIX + "dag_owner",
        env_var_format=ENV_VAR_FORMAT_PREFIX + "DAG_OWNER",
    ),
    "AIRFLOW_CONTEXT_DAG_EMAIL": dict(
        default=DEFAULT_FORMAT_PREFIX + "dag_email",
        env_var_format=ENV_VAR_FORMAT_PREFIX + "DAG_EMAIL",
    ),
}

65 

66 

def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
    """
    Return values used to externally reconstruct relations between dags, dag_runs, tasks and task_instances.

    Given a context, this function provides a dictionary of values that can be used to
    externally reconstruct relations between dags, dag_runs, tasks and task_instances.
    Names default to the ``airflow.ctx.<name>`` (abc.def.ghi) format and switch to the
    ``AIRFLOW_CTX_<NAME>`` (ABC_DEF_GHI) format if ``in_env_var_format`` is True.

    Extra variables returned by ``settings.get_airflow_context_vars(context)`` are
    included too. A custom key that does not already start with the format's prefix
    gets that prefix prepended; in env-var format, such a key is also upper-cased.
    Built-in values taken from ``task``, ``task_instance`` and ``dag_run`` are written
    after the custom ones, so they win if both produce the same final name.

    :param context: The context for the task_instance of interest.
    :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
    :return: task_instance context as a dict of string names to string values.
    :raises TypeError: If a custom context variable has a non-string key or value.
    """
    name_format = "env_var_format" if in_env_var_format else "default"

    task = context.get("task")
    task_instance = context.get("task_instance")
    dag_run = context.get("dag_run")

    # (object to read from, attribute name, key into AIRFLOW_VAR_NAME_FORMAT_MAPPING)
    ops = [
        (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
        (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
        (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
        (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
        (task_instance, "execution_date", "AIRFLOW_CONTEXT_EXECUTION_DATE"),
        (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
        (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
    ]

    params: dict[str, str] = {}

    # User-supplied extra context variables; values must already be strings.
    context_params = settings.get_airflow_context_vars(context)
    for key, value in context_params.items():
        if not isinstance(key, str):
            raise TypeError(f"key <{key}> must be string")
        if not isinstance(value, str):
            raise TypeError(f"value of key <{key}> must be string, not {type(value)}")

        if in_env_var_format:
            if not key.startswith(ENV_VAR_FORMAT_PREFIX):
                key = ENV_VAR_FORMAT_PREFIX + key.upper()
        else:
            if not key.startswith(DEFAULT_FORMAT_PREFIX):
                key = DEFAULT_FORMAT_PREFIX + key
        params[key] = value

    for subject, attr, mapping_key in ops:
        _attr = getattr(subject, attr, None)
        # Missing subjects and empty/falsy attribute values are skipped entirely.
        if subject and _attr:
            mapping_value = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
            if isinstance(_attr, str):
                params[mapping_value] = _attr
            elif isinstance(_attr, datetime):
                params[mapping_value] = _attr.isoformat()
            elif isinstance(_attr, (list, tuple)):
                # os env variable value needs to be string; ``task.email`` may be any
                # iterable of addresses, so join sequences instead of using their repr.
                params[mapping_value] = ",".join(_attr)
            elif isinstance(_attr, (set, frozenset)):
                # Sort unordered collections so the joined value is deterministic.
                params[mapping_value] = ",".join(sorted(_attr))
            else:
                params[mapping_value] = str(_attr)

    return params

130 

131 

class KeywordParameters:
    """Wrapper representing ``**kwargs`` to a callable.

    Call ``unpacking()`` or ``serializing()`` to get the actual ``kwargs``.
    The two return the same mapping in every case but one: the wrapped
    ``kwargs`` is an Airflow Context object *and* the target callable declares
    ``**kwargs`` in its argument list.

    In that case ``unpacking()`` hands back a ``lazy-object-proxy`` based
    mapping. This stops the Context from emitting deprecation warnings too
    eagerly when it is expanded with ``**``. ``serializing()`` never proxies,
    so warnings may fire eagerly. That is the right choice when dumping the
    content for use elsewhere without needing ``lazy-object-proxy``.
    """

    def __init__(self, kwargs: Mapping[str, Any], *, wildcard: bool) -> None:
        # ``wildcard`` records whether the target callable declares ``**kwargs``.
        self._kwargs = kwargs
        self._wildcard = wildcard

    @classmethod
    def determine(
        cls,
        func: Callable[..., Any],
        args: Collection[Any],
        kwargs: Mapping[str, Any],
    ) -> KeywordParameters:
        """Work out which keyword arguments ``func`` should receive alongside ``args``.

        :param func: The callable whose signature is inspected.
        :param args: Positional arguments that will be passed to ``func``. Only
            their count is used, to find which parameters they fill.
        :param kwargs: Candidate keyword arguments.
        :return: A wrapper holding ``kwargs`` unchanged when ``func`` declares
            ``**kwargs``. Otherwise it holds only the keys ``func`` names explicitly.
        :raises ValueError: If a parameter filled positionally also appears in ``kwargs``.
        """
        import inspect
        import itertools

        parameters = inspect.signature(func).parameters

        # The leading len(args) parameters are consumed positionally, so a keyword
        # with one of those names would clash.
        for positional_name in itertools.islice(parameters, len(args)):
            if positional_name in kwargs:
                raise ValueError(f"The key {positional_name!r} in args is a part of kwargs and therefore reserved.")

        takes_var_keyword = any(param.kind == param.VAR_KEYWORD for param in parameters.values())
        if takes_var_keyword:
            # A ``**kwargs`` parameter accepts everything: pass the mapping through as-is.
            return cls(kwargs, wildcard=True)

        # Without ``**kwargs``, forward only the keywords the signature asks for.
        accepted = {name: kwargs[name] for name in parameters if name in kwargs}
        return cls(accepted, wildcard=False)

    def unpacking(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping to unpack with ``**`` in a function call."""
        if not self._wildcard:
            return self._kwargs
        if isinstance(self._kwargs, Context):  # type: ignore[misc]
            # Wrap lazily so deprecated Context keys only warn when actually accessed.
            return lazy_mapping_from_context(self._kwargs)
        return self._kwargs

    def serializing(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping for serialization purposes (never proxied)."""
        return self._kwargs

187 

def determine_kwargs(
    func: Callable[..., Any],
    args: Collection[Any],
    kwargs: Mapping[str, Any],
) -> Mapping[str, Any]:
    """
    Select the keyword arguments that ``func`` is able to accept.

    This is a convenience shortcut for
    ``KeywordParameters.determine(func, args, kwargs).unpacking()``.

    :param func: The callable that you want to invoke.
    :param args: The positional arguments that will be passed to the callable. Only
        their count matters, to know how many leading parameters to skip.
    :param kwargs: The keyword arguments to filter before passing them to the callable.
    :return: A mapping holding only the keyword arguments compatible with the callable.
    :raises ValueError: If a parameter filled by ``args`` is also a key in ``kwargs``.
    """
    keyword_parameters = KeywordParameters.determine(func, args, kwargs)
    return keyword_parameters.unpacking()

202 

203 

def make_kwargs_callable(func: Callable[..., R]) -> Callable[..., R]:
    """
    Wrap ``func`` so it tolerates surplus keyword arguments.

    The returned callable accepts any positional and keyword arguments. It
    forwards every positional argument unchanged, but filters the keyword
    arguments through ``determine_kwargs`` so ``func`` only receives the ones
    it can accept. ``functools.wraps`` copies ``func``'s metadata (name,
    docstring, ...) onto the wrapper.

    :param func: The callable to wrap.
    :return: The wrapping callable; it returns whatever ``func`` returns.
    """
    from functools import wraps

    @wraps(func)
    def forward_accepted_kwargs(*call_args, **call_kwargs):
        accepted = determine_kwargs(func, call_args, call_kwargs)
        return func(*call_args, **accepted)

    return forward_accepted_kwargs

219 

220 

class ExecutionCallableRunner:
    """Invoke an execution callable with a task context and the given arguments.

    A plain callable is called once with the supplied arguments (the context
    included) and its result is returned unchanged. A generator function is
    instead driven to exhaustion here. Each yielded ``Metadata`` object has its
    ``extra`` merged into the outlet event for its ``uri``. Any other yielded
    value is ignored with a warning. The generator's ``return`` value is what
    ``run`` returns.

    :meta private:
    """

    def __init__(
        self,
        func: Callable,
        outlet_events: OutletEventAccessors,
        *,
        logger: logging.Logger | None,
    ) -> None:
        self.func = func
        self.outlet_events = outlet_events
        # Fall back to this module's logger when the caller does not supply one.
        self.logger = logger or logging.getLogger(__name__)

    def run(self, *args, **kwargs) -> Any:
        """Call ``self.func`` with ``args``/``kwargs`` and return its result.

        :return: The callable's return value. For a generator function, this is
            the value the generator ``return``s once it has been exhausted.
        """
        import inspect

        from airflow.datasets.metadata import Metadata
        from airflow.utils.types import NOTSET

        if not inspect.isgeneratorfunction(self.func):
            return self.func(*args, **kwargs)

        # Set via ``nonlocal`` below. ``yield from`` evaluates to the generator's return value.
        final_value: Any = NOTSET

        def _drain():
            nonlocal final_value
            final_value = yield from self.func(*args, **kwargs)

        for yielded in _drain():
            if not isinstance(yielded, Metadata):
                self.logger.warning("Ignoring unknown data of %r received from task", type(yielded))
                # Skip the debug call entirely unless DEBUG output is enabled.
                if self.logger.isEnabledFor(logging.DEBUG):
                    self.logger.debug("Full yielded value: %r", yielded)
                continue
            self.outlet_events[yielded.uri].extra.update(yielded.extra)

        return final_value