Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/context.py: 36%

125 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Jinja2 template rendering context helper.""" 

19from __future__ import annotations 

20 

21import contextlib 

22import copy 

23import functools 

24import warnings 

25from typing import ( 

26 TYPE_CHECKING, 

27 Any, 

28 Container, 

29 ItemsView, 

30 Iterator, 

31 KeysView, 

32 Mapping, 

33 MutableMapping, 

34 SupportsIndex, 

35 ValuesView, 

36) 

37 

38import lazy_object_proxy 

39 

40from airflow.exceptions import RemovedInAirflow3Warning 

41from airflow.utils.types import NOTSET 

42 

43if TYPE_CHECKING: 

44 from airflow.models.baseoperator import BaseOperator 

45 

# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi.
# Entries are listed alphabetically to ease comparison with context.pyi;
# ordering has no runtime effect because this is a set.
KNOWN_CONTEXT_KEYS = {
    "conf",
    "conn",
    "dag",
    "dag_run",
    "data_interval_end",
    "data_interval_start",
    "ds",
    "ds_nodash",
    "exception",
    "execution_date",
    "expanded_ti_count",
    "inlets",
    "logical_date",
    "macros",
    "next_ds",
    "next_ds_nodash",
    "next_execution_date",
    "outlets",
    "params",
    "prev_data_interval_end_success",
    "prev_data_interval_start_success",
    "prev_ds",
    "prev_ds_nodash",
    "prev_execution_date",
    "prev_execution_date_success",
    "prev_start_date_success",
    "run_id",
    "task",
    "task_instance",
    "task_instance_key_str",
    "templates_dict",
    "test_mode",
    "ti",
    "tomorrow_ds",
    "tomorrow_ds_nodash",
    "triggering_dataset_events",
    "try_number",
    "ts",
    "ts_nodash",
    "ts_nodash_with_tz",
    "var",
    "yesterday_ds",
    "yesterday_ds_nodash",
}

92 

93 

class VariableAccessor:
    """Wrapper to access Variable values in template.

    ``accessor.my_key`` and ``accessor.get("my_key")`` both look up the Airflow
    Variable named ``my_key``, passing ``deserialize_json`` through to
    ``Variable.get``. Because ``__getattr__`` only runs for names that normal
    attribute lookup cannot find, names that already exist on the instance
    (``var``, ``get``, ``_deserialize_json``) resolve to those attributes rather
    than to Variables.
    """

    def __init__(self, *, deserialize_json: bool) -> None:
        self._deserialize_json = deserialize_json
        # Value returned by the most recent attribute-style lookup; shown by ``__repr__``.
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        """Fetch the Variable named ``key`` and remember it in ``self.var``.

        :raises AttributeError: for dunder names and for ``_deserialize_json``,
            which are never treated as Variable names.
        """
        # Protocol probes such as ``getattr(obj, "__deepcopy__", None)`` (copy.deepcopy)
        # or ``hasattr(obj, "__setstate__")`` (copy/pickle reconstruction) expect an
        # AttributeError when the hook is absent; without this guard they were sent to
        # the Variable backend as if they were Variable names.
        # ``_deserialize_json`` can only be missing on an instance created without
        # ``__init__`` (as copy/pickle do); resolving it here would read
        # ``self._deserialize_json`` again and recurse forever.
        if key == "_deserialize_json" or (key.startswith("__") and key.endswith("__")):
            raise AttributeError(key)
        from airflow.models.variable import Variable

        self.var = Variable.get(key, deserialize_json=self._deserialize_json)
        return self.var

    def __repr__(self) -> str:
        return str(self.var)

    def get(self, key, default: Any = NOTSET) -> Any:
        """Return the Variable named ``key``, or ``default`` if one was supplied.

        When ``default`` is omitted, ``Variable.get`` is called without a default,
        so a missing Variable is handled however ``Variable.get`` handles it.
        """
        from airflow.models.variable import Variable

        if default is NOTSET:
            return Variable.get(key, deserialize_json=self._deserialize_json)
        return Variable.get(key, default, deserialize_json=self._deserialize_json)

116 

117 

class ConnectionAccessor:
    """Wrapper to access Connection entries in template.

    ``accessor.my_conn`` looks up ``my_conn`` via
    ``Connection.get_connection_from_secrets``; ``get`` does the same but returns
    ``default_conn`` when that raises ``AirflowNotFoundException``. Names that
    already exist on the instance (``var``, ``get``) resolve to those attributes
    rather than to connections.
    """

    def __init__(self) -> None:
        # Connection returned by the most recent attribute-style lookup; shown by ``__repr__``.
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        """Fetch the connection named ``key`` and remember it in ``self.var``.

        :raises AttributeError: for dunder names, which are never treated as
            connection ids.
        """
        # Protocol probes such as ``getattr(obj, "__deepcopy__", None)`` (copy.deepcopy)
        # or ``__getstate__``/``__setstate__`` lookups (copy/pickle) expect an
        # AttributeError when the hook is absent; without this guard they triggered a
        # secrets-backend lookup for a connection with that name.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        from airflow.models.connection import Connection

        self.var = Connection.get_connection_from_secrets(key)
        return self.var

    def __repr__(self) -> str:
        return str(self.var)

    def get(self, key: str, default_conn: Any = None) -> Any:
        """Return the connection ``key``, or ``default_conn`` if it is not found."""
        from airflow.exceptions import AirflowNotFoundException
        from airflow.models.connection import Connection

        try:
            return Connection.get_connection_from_secrets(key)
        except AirflowNotFoundException:
            return default_conn

141 

142 

class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
    """Warning category used when a task or template reads a deprecated context key."""

145 

146 

def _create_deprecation_warning(key: str, replacements: list[str]) -> RemovedInAirflow3Warning:
    """Build the warning emitted when a deprecated context key is read.

    :param key: The deprecated context key that was accessed.
    :param replacements: Suggested alternatives. When empty, the message
        carries no suggestion; otherwise they are listed as
        ``'a', 'b' or 'c'``.
    :return: An ``AirflowContextDeprecationWarning`` carrying the message.
    """
    text = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
    if replacements:
        *leading, final = replacements
        if leading:
            suggestion = ", ".join(repr(item) for item in leading) + f" or {final!r}"
        else:
            suggestion = repr(final)
        text += f" Please use {suggestion} instead."
    return AirflowContextDeprecationWarning(text)

157 

158 

class Context(MutableMapping[str, Any]):
    """Dict-like Jinja2 template context used when rendering a task.

    Deprecated keys emit an ``AirflowContextDeprecationWarning`` only when they
    are read through ``__getitem__`` (which ``get`` and pickling also use). The
    ``keys``/``items``/``values`` views read the wrapped mapping directly and
    never warn. Writing or deleting a key drops its deprecation entry, so later
    reads of that key are silent.
    """

    _DEPRECATION_REPLACEMENTS: dict[str, list[str]] = {
        "execution_date": ["data_interval_start", "logical_date"],
        "next_ds": ["{{ data_interval_end | ds }}"],
        "next_ds_nodash": ["{{ data_interval_end | ds_nodash }}"],
        "next_execution_date": ["data_interval_end"],
        "prev_ds": [],
        "prev_ds_nodash": [],
        "prev_execution_date": [],
        "prev_execution_date_success": ["prev_data_interval_start_success"],
        "tomorrow_ds": [],
        "tomorrow_ds_nodash": [],
        "yesterday_ds": [],
        "yesterday_ds_nodash": [],
    }

    def __init__(self, context: MutableMapping[str, Any] | None = None, **kwargs: Any) -> None:
        # A non-empty ``context`` is wrapped as-is (not copied), so ``kwargs`` are
        # written into the caller's mapping; an empty or missing one is replaced
        # by a fresh dict.
        backing: MutableMapping[str, Any] = context if context else {}
        if kwargs:
            backing.update(kwargs)
        self._context = backing
        # Per-instance copy so dropping an entry never touches the class-level table.
        self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()

    def __repr__(self) -> str:
        return f"{self._context!r}"

    def __reduce_ex__(self, protocol: SupportsIndex) -> tuple[Any, ...]:
        """Pickle the context as a plain ``dict``.

        Values are read via ``self[key]`` rather than ``items()`` on purpose, so
        deprecated keys still emit their warnings when the context is pickled.
        """
        return dict, ([(name, self[name]) for name in self._context],)

    def __copy__(self) -> Context:
        duplicate = type(self)(copy.copy(self._context))
        duplicate._deprecation_replacements = self._deprecation_replacements.copy()
        return duplicate

    def __getitem__(self, key: str) -> Any:
        if key in self._deprecation_replacements:
            warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key]))
        try:
            return self._context[key]
        except KeyError:
            pass
        # Raised outside the ``except`` block so no chained exception is attached,
        # and always carrying ``key`` itself as the argument.
        raise KeyError(key)

    def __setitem__(self, key: str, value: Any) -> None:
        # An explicit write replaces the deprecated value, so stop warning about it.
        self._deprecation_replacements.pop(key, None)
        self._context[key] = value

    def __delitem__(self, key: str) -> None:
        # The deprecation entry is dropped before the delete, even if the delete raises.
        self._deprecation_replacements.pop(key, None)
        del self._context[key]

    def __contains__(self, key: object) -> bool:
        return key in self._context

    def __iter__(self) -> Iterator[str]:
        return iter(self._context)

    def __len__(self) -> int:
        return len(self._context)

    def __eq__(self, other: Any) -> bool:
        # Only Context-to-Context comparison is handled here; anything else is
        # deferred to the other operand.
        if isinstance(other, Context):
            return self._context == other._context
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        if isinstance(other, Context):
            return self._context != other._context
        return NotImplemented

    def keys(self) -> KeysView[str]:
        return self._context.keys()

    def items(self):
        # Built on the wrapped mapping, so iterating it never goes through
        # ``__getitem__`` and never emits deprecation warnings.
        return ItemsView(self._context)

    def values(self):
        # Same as ``items``: reads the wrapped mapping directly, without warnings.
        return ValuesView(self._context)

246 

247 

def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
    """Update ``context`` in place, accepting the same arguments as ``dict.update()``.

    This lives at module level rather than as a ``Context`` method because
    ``context.pyi`` presents ``Context`` as a ``TypedDict``, and a ``TypedDict``
    cannot declare custom methods.

    :meta private:
    """
    update = context.update
    update(*args, **kwargs)

261 

262 

def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """Point ``context`` at the unmapped ``task``.

    ``get_template_context()`` runs before the task is unmapped, so the context
    still describes the mapped task. This rewrites the affected entries in place
    (``task``, ``ti.task`` and ``params``) so templates see the unmapped task.

    :meta private:
    """
    from airflow.models.param import process_params

    # Same order as the chained form ``context["task"] = context["ti"].task = task``:
    # the context entry is written first, then the task instance's attribute.
    context["task"] = task
    context["ti"].task = task
    context["params"] = process_params(context["dag"], task, context["dag_run"], suppress_exception=False)

276 

277 

def context_copy_partial(source: Context, keys: Container[str]) -> Context:
    """Build a new ``Context`` holding only the entries of ``source`` whose key is in ``keys``.

    Values are read from the wrapped mapping directly, so no deprecation warnings
    are emitted while copying. The full deprecation table of ``source`` is carried
    over to the new context.

    This lives at module level rather than as a ``Context`` method because
    ``context.pyi`` presents ``Context`` as a ``TypedDict``, and a ``TypedDict``
    cannot declare custom methods.

    :meta private:
    """
    selected: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in keys:
            selected[name] = value
    subset = Context(selected)
    subset._deprecation_replacements = source._deprecation_replacements.copy()
    return subset

290 

291 

def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
    """Create a mapping that wraps deprecated entries in a lazy object proxy.

    This further delays the deprecation warning until the entry is actually used,
    instead of when it is accessed in the context. The result is useful for
    passing into a callable with ``**kwargs``, which would otherwise unpack the
    mapping too eagerly.

    This is implemented as a free function because the ``Context`` type is
    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
    functions.

    :param source: The context to wrap. Anything that is not a ``Context``
        (e.g. a plain dict) is returned unchanged.
    :return: A new dict with the same keys as ``source``. Values of keys that are
        still deprecated are ``lazy_object_proxy.Proxy`` objects whose factory
        emits the deprecation warning; all other values are passed through as-is.

    :meta private:
    """
    if not isinstance(source, Context):
        # Sometimes we are passed a plain dict (usually in tests, or in users'
        # custom operators) -- be lenient about what we accept so we don't
        # break anything for users.
        return source

    def _deprecated_proxy_factory(k: str, v: Any, replacements: list[str]) -> Any:
        warnings.warn(_create_deprecation_warning(k, replacements))
        return v

    def _create_value(k: str, v: Any) -> Any:
        if k not in source._deprecation_replacements:
            return v
        # Snapshot the replacements now. The proxy resolves later, by which time
        # ``source[k]`` may have been set or deleted; that removes ``k`` from
        # ``source._deprecation_replacements`` and a late lookup would raise KeyError.
        replacements = source._deprecation_replacements[k]
        factory = functools.partial(_deprecated_proxy_factory, k, v, replacements)
        return lazy_object_proxy.Proxy(factory)

    return {k: _create_value(k, v) for k, v in source._context.items()}