Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/context.py: 36%

125 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Jinja2 template rendering context helper.""" 

19from __future__ import annotations 

20 

21import contextlib 

22import copy 

23import functools 

24import warnings 

25from typing import ( 

26 TYPE_CHECKING, 

27 Any, 

28 Container, 

29 ItemsView, 

30 Iterator, 

31 KeysView, 

32 Mapping, 

33 MutableMapping, 

34 ValuesView, 

35) 

36 

37import lazy_object_proxy 

38 

39from airflow.exceptions import RemovedInAirflow3Warning 

40from airflow.utils.types import NOTSET 

41 

42if TYPE_CHECKING: 

43 from airflow.models.baseoperator import BaseOperator 

44 

# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi.
# The entries are listed alphabetically; being a set, the order carries no
# meaning at runtime.
KNOWN_CONTEXT_KEYS = {
    "conf",
    "conn",
    "dag",
    "dag_run",
    "data_interval_end",
    "data_interval_start",
    "ds",
    "ds_nodash",
    "exception",
    "execution_date",
    "expanded_ti_count",
    "inlets",
    "logical_date",
    "macros",
    "next_ds",
    "next_ds_nodash",
    "next_execution_date",
    "outlets",
    "params",
    "prev_data_interval_end_success",
    "prev_data_interval_start_success",
    "prev_ds",
    "prev_ds_nodash",
    "prev_execution_date",
    "prev_execution_date_success",
    "prev_start_date_success",
    "run_id",
    "task",
    "task_instance",
    "task_instance_key_str",
    "templates_dict",
    "test_mode",
    "ti",
    "tomorrow_ds",
    "tomorrow_ds_nodash",
    "triggering_dataset_events",
    "try_number",
    "ts",
    "ts_nodash",
    "ts_nodash_with_tz",
    "var",
    "yesterday_ds",
    "yesterday_ds_nodash",
}

91 

92 

class VariableAccessor:
    """Template-side wrapper giving access to Variable values.

    Attribute names that normal lookup cannot resolve are treated as variable
    keys and fetched through ``Variable.get``. The value most recently fetched
    that way is kept in ``var`` and is what ``repr()`` displays.
    """

    def __init__(self, *, deserialize_json: bool) -> None:
        # Last value fetched via attribute access; ``None`` until the first one.
        self.var: Any = None
        # Passed through unchanged to every ``Variable.get`` call.
        self._deserialize_json = deserialize_json

    def __getattr__(self, key: str) -> Any:
        """Fetch the variable named ``key`` and remember it in ``var``."""
        from airflow.models.variable import Variable

        fetched = Variable.get(key, deserialize_json=self._deserialize_json)
        self.var = fetched
        return fetched

    def __repr__(self) -> str:
        """Return the string form of the last value fetched via attribute access."""
        return str(self.var)

    def get(self, key, default: Any = NOTSET) -> Any:
        """Fetch the variable named ``key`` without touching ``var``.

        ``default`` is only forwarded to ``Variable.get`` when the caller
        supplied one; the ``NOTSET`` sentinel marks "no default given" so that
        ``None`` can still be passed as a real default.
        """
        from airflow.models.variable import Variable

        if default is not NOTSET:
            return Variable.get(key, default, deserialize_json=self._deserialize_json)
        return Variable.get(key, deserialize_json=self._deserialize_json)

115 

116 

class ConnectionAccessor:
    """Template-side wrapper giving access to Connection entries.

    Attribute names that normal lookup cannot resolve are treated as
    connection ids and resolved with ``Connection.get_connection_from_secrets``.
    The entry most recently resolved that way is kept in ``var`` and is what
    ``repr()`` displays.
    """

    def __init__(self) -> None:
        # Last connection fetched via attribute access; ``None`` until the first one.
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        """Resolve the connection named ``key`` and remember it in ``var``."""
        from airflow.models.connection import Connection

        resolved = Connection.get_connection_from_secrets(key)
        self.var = resolved
        return resolved

    def __repr__(self) -> str:
        """Return the string form of the last connection fetched via attribute access."""
        return str(self.var)

    def get(self, key: str, default_conn: Any = None) -> Any:
        """Resolve the connection named ``key`` without touching ``var``.

        Returns ``default_conn`` when the lookup raises
        ``AirflowNotFoundException``; any other exception propagates.
        """
        from airflow.exceptions import AirflowNotFoundException
        from airflow.models.connection import Connection

        try:
            resolved = Connection.get_connection_from_secrets(key)
        except AirflowNotFoundException:
            return default_conn
        return resolved

140 

141 

class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
    """Warning category raised when a task uses a deprecated context variable."""

144 

145 

def _create_deprecation_warning(key: str, replacements: list[str]) -> RemovedInAirflow3Warning:
    """Build the deprecation warning instance for a deprecated context key.

    :param key: The deprecated context key that was accessed.
    :param replacements: Names to suggest instead. With none, the message only
        states the deprecation; with one, it suggests that name; with several,
        they are rendered as ``'a', 'b' or 'c'``.
    :return: An ``AirflowContextDeprecationWarning`` carrying the message. The
        warning is only constructed here, not emitted.
    """
    message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
    if replacements:
        leading, final = replacements[:-1], replacements[-1]
        if leading:
            listed = ", ".join(repr(name) for name in leading)
            message += f" Please use {listed} or {final!r} instead."
        else:
            message += f" Please use {final!r} instead."
    return AirflowContextDeprecationWarning(message)

156 

157 

158class Context(MutableMapping[str, Any]): 

159 """Jinja2 template context for task rendering. 

160 

161 This is a mapping (dict-like) class that can lazily emit warnings when 

162 (and only when) deprecated context keys are accessed. 

163 """ 

164 

165 _DEPRECATION_REPLACEMENTS: dict[str, list[str]] = { 

166 "execution_date": ["data_interval_start", "logical_date"], 

167 "next_ds": ["{{ data_interval_end | ds }}"], 

168 "next_ds_nodash": ["{{ data_interval_end | ds_nodash }}"], 

169 "next_execution_date": ["data_interval_end"], 

170 "prev_ds": [], 

171 "prev_ds_nodash": [], 

172 "prev_execution_date": [], 

173 "prev_execution_date_success": ["prev_data_interval_start_success"], 

174 "tomorrow_ds": [], 

175 "tomorrow_ds_nodash": [], 

176 "yesterday_ds": [], 

177 "yesterday_ds_nodash": [], 

178 } 

179 

180 def __init__(self, context: MutableMapping[str, Any] | None = None, **kwargs: Any) -> None: 

181 self._context: MutableMapping[str, Any] = context or {} 

182 if kwargs: 

183 self._context.update(kwargs) 

184 self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy() 

185 

186 def __repr__(self) -> str: 

187 return repr(self._context) 

188 

189 def __reduce_ex__(self, protocol: int) -> tuple[Any, ...]: 

190 """Pickle the context as a dict. 

191 

192 We are intentionally going through ``__getitem__`` in this function, 

193 instead of using ``items()``, to trigger deprecation warnings. 

194 """ 

195 items = [(key, self[key]) for key in self._context] 

196 return dict, (items,) 

197 

198 def __copy__(self) -> Context: 

199 new = type(self)(copy.copy(self._context)) 

200 new._deprecation_replacements = self._deprecation_replacements.copy() 

201 return new 

202 

203 def __getitem__(self, key: str) -> Any: 

204 with contextlib.suppress(KeyError): 

205 warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key])) 

206 with contextlib.suppress(KeyError): 

207 return self._context[key] 

208 raise KeyError(key) 

209 

210 def __setitem__(self, key: str, value: Any) -> None: 

211 self._deprecation_replacements.pop(key, None) 

212 self._context[key] = value 

213 

214 def __delitem__(self, key: str) -> None: 

215 self._deprecation_replacements.pop(key, None) 

216 del self._context[key] 

217 

218 def __contains__(self, key: object) -> bool: 

219 return key in self._context 

220 

221 def __iter__(self) -> Iterator[str]: 

222 return iter(self._context) 

223 

224 def __len__(self) -> int: 

225 return len(self._context) 

226 

227 def __eq__(self, other: Any) -> bool: 

228 if not isinstance(other, Context): 

229 return NotImplemented 

230 return self._context == other._context 

231 

232 def __ne__(self, other: Any) -> bool: 

233 if not isinstance(other, Context): 

234 return NotImplemented 

235 return self._context != other._context 

236 

237 def keys(self) -> KeysView[str]: 

238 return self._context.keys() 

239 

240 def items(self): 

241 return ItemsView(self._context) 

242 

243 def values(self): 

244 return ValuesView(self._context) 

245 

246 

def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
    """Update ``context`` in place with the given entries.

    Accepts exactly the arguments ``dict.update()`` does and forwards them to
    ``context.update()``; nothing is returned.

    This lives outside the class because ``Context`` is "faked" as a
    ``TypedDict`` in ``context.pyi``, and a ``TypedDict`` cannot declare
    custom functions.

    :meta private:
    """
    context.update(*args, **kwargs)

260 

261 

def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """Refresh ``context`` in place once a mapped task has been unmapped.

    ``get_template_context()`` runs before unmapping, so the context it built
    still describes the mapped task. This points the ``task`` entry and the
    task instance's ``task`` attribute at the unmapped ``task``, then rebuilds
    the ``params`` entry for it.

    :meta private:
    """
    from airflow.models.param import process_params

    # Same order as a chained assignment: the context entry first, then the
    # attribute on the task instance.
    context["task"] = task
    context["ti"].task = task

    rebuilt_params = process_params(context["dag"], task, context["dag_run"], suppress_exception=False)
    context["params"] = rebuilt_params

275 

276 

def context_copy_partial(source: Context, keys: Container[str]) -> Context:
    """Build a new context holding only the entries of ``source`` named in ``keys``.

    Values are read straight from the wrapped mapping, so copying emits no
    deprecation warnings. The new context starts with a copy of ``source``'s
    deprecation table.

    This lives outside the class because ``Context`` is "faked" as a
    ``TypedDict`` in ``context.pyi``, and a ``TypedDict`` cannot declare
    custom functions.

    :meta private:
    """
    selected: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in keys:
            selected[name] = value

    partial_context = Context(selected)
    partial_context._deprecation_replacements = source._deprecation_replacements.copy()
    return partial_context

289 

290 

def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
    """Return a plain mapping whose deprecated entries are lazy object proxies.

    Reading a deprecated key from a ``Context`` warns immediately. Wrapping
    the value in a proxy postpones the warning until the value itself is
    used, which matters when the mapping is unpacked with ``**kwargs`` into a
    callable: unpacking touches every key, used or not.

    This lives outside the class because ``Context`` is "faked" as a
    ``TypedDict`` in ``context.pyi``, and a ``TypedDict`` cannot declare
    custom functions.

    :meta private:
    """
    if not isinstance(source, Context):
        # Plain dicts do show up here (mostly from tests or user-written
        # operators); be lenient and hand them back untouched rather than
        # breaking those callers.
        return source

    def _warn_and_unwrap(name: str, value: Any) -> Any:
        # Runs when the proxy is first resolved: emit the warning, then
        # surrender the real value.
        suggestions = source._deprecation_replacements[name]
        warnings.warn(_create_deprecation_warning(name, suggestions))
        return value

    lazy: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in source._deprecation_replacements:
            lazy[name] = lazy_object_proxy.Proxy(functools.partial(_warn_and_unwrap, name, value))
        else:
            lazy[name] = value
    return lazy