Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/context.py: 41%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

179 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Jinja2 template rendering context helper.""" 

19 

20from __future__ import annotations 

21 

22import contextlib 

23import copy 

24import functools 

25import warnings 

26from typing import ( 

27 TYPE_CHECKING, 

28 Any, 

29 Container, 

30 ItemsView, 

31 Iterator, 

32 KeysView, 

33 Mapping, 

34 MutableMapping, 

35 SupportsIndex, 

36 ValuesView, 

37) 

38 

39import attrs 

40import lazy_object_proxy 

41from sqlalchemy import select 

42 

43from airflow.datasets import Dataset, coerce_to_uri 

44from airflow.exceptions import RemovedInAirflow3Warning 

45from airflow.models.dataset import DatasetEvent, DatasetModel 

46from airflow.utils.db import LazySelectSequence 

47from airflow.utils.types import NOTSET 

48 

49if TYPE_CHECKING: 

50 from sqlalchemy.engine import Row 

51 from sqlalchemy.orm import Session 

52 from sqlalchemy.sql.expression import Select, TextClause 

53 

54 from airflow.models.baseoperator import BaseOperator 

55 

56# NOTE: Please keep this in sync with the following: 

57# * Context in airflow/utils/context.pyi. 

58# * Table in docs/apache-airflow/templates-ref.rst 

# Every key name recognised in the template context, listed alphabetically.
KNOWN_CONTEXT_KEYS: set[str] = {
    "conf",
    "conn",
    "dag",
    "dag_run",
    "data_interval_end",
    "data_interval_start",
    "ds",
    "ds_nodash",
    "exception",
    "execution_date",
    "expanded_ti_count",
    "inlet_events",
    "inlets",
    "logical_date",
    "macros",
    "map_index_template",
    "next_ds",
    "next_ds_nodash",
    "next_execution_date",
    "outlet_events",
    "outlets",
    "params",
    "prev_data_interval_end_success",
    "prev_data_interval_start_success",
    "prev_ds",
    "prev_ds_nodash",
    "prev_end_date_success",
    "prev_execution_date",
    "prev_execution_date_success",
    "prev_start_date_success",
    "reason",
    "run_id",
    "task",
    "task_instance",
    "task_instance_key_str",
    "templates_dict",
    "test_mode",
    "ti",
    "tomorrow_ds",
    "tomorrow_ds_nodash",
    "triggering_dataset_events",
    "try_number",
    "ts",
    "ts_nodash",
    "ts_nodash_with_tz",
    "var",
    "yesterday_ds",
    "yesterday_ds_nodash",
}

109 

110 

class VariableAccessor:
    """Expose Airflow Variables to templates.

    Attribute access (``accessor.some_key``) looks the key up through
    ``Variable.get``, remembers the fetched value on ``self.var`` and returns
    it. ``repr()`` shows the most recently fetched value (``None`` before any
    attribute lookup happened).
    """

    def __init__(self, *, deserialize_json: bool) -> None:
        self._deserialize_json = deserialize_json
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal attribute lookup fails, so ``var`` and
        # ``_deserialize_json`` never land here once ``__init__`` has run.
        from airflow.models.variable import Variable

        value = Variable.get(key, deserialize_json=self._deserialize_json)
        self.var = value
        return value

    def __repr__(self) -> str:
        return f"{self.var!s}"

    def get(self, key, default: Any = NOTSET) -> Any:
        """Fetch variable ``key``; pass ``default`` on to ``Variable.get`` only when given."""
        from airflow.models.variable import Variable

        if default is not NOTSET:
            return Variable.get(key, default, deserialize_json=self._deserialize_json)
        return Variable.get(key, deserialize_json=self._deserialize_json)

133 

134 

class ConnectionAccessor:
    """Expose Airflow Connections to templates.

    Attribute access (``accessor.some_conn_id``) resolves the connection via
    ``Connection.get_connection_from_secrets``, remembers it on ``self.var``
    and returns it. ``repr()`` shows the most recently resolved connection
    (``None`` before any attribute lookup happened).
    """

    def __init__(self) -> None:
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal attribute lookup fails.
        from airflow.models.connection import Connection

        connection = Connection.get_connection_from_secrets(key)
        self.var = connection
        return connection

    def __repr__(self) -> str:
        return f"{self.var!s}"

    def get(self, key: str, default_conn: Any = None) -> Any:
        """Return the connection for ``key``, or ``default_conn`` if it is not found."""
        from airflow.exceptions import AirflowNotFoundException
        from airflow.models.connection import Connection

        try:
            connection = Connection.get_connection_from_secrets(key)
        except AirflowNotFoundException:
            return default_conn
        else:
            return connection

158 

159 

@attrs.define
class OutletEventAccessor:
    """Template-side handle for a single outlet dataset event.

    :meta private:
    """

    # Dict carried by this accessor; ``OutletEventAccessors`` creates
    # accessors with an empty one.
    extra: dict[str, Any]

168 

169 

class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
    """Mapping from dataset URI to its outlet event accessor, filled on demand.

    Looking up a key that has not been seen yet creates (and stores) a new
    accessor with an empty ``extra`` dict instead of raising ``KeyError``.

    :meta private:
    """

    def __init__(self) -> None:
        self._dict: dict[str, OutletEventAccessor] = {}

    def __iter__(self) -> Iterator[str]:
        return iter(self._dict.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __getitem__(self, key: str | Dataset) -> OutletEventAccessor:
        # Entries are stored under whatever ``coerce_to_uri`` returns for the key.
        uri = coerce_to_uri(key)
        try:
            return self._dict[uri]
        except KeyError:
            accessor = OutletEventAccessor({})
            self._dict[uri] = accessor
            return accessor

189 

190 

class LazyDatasetEventSelectSequence(LazySelectSequence[DatasetEvent]):
    """Lazily evaluated, list-like view over ``DatasetEvent`` rows.

    :meta private:
    """

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        # Wrap the textual statement in a DatasetEvent select.
        event_select = select(DatasetEvent)
        return event_select.from_statement(stmt)

    @staticmethod
    def _process_row(row: Row) -> DatasetEvent:
        # The event is taken from the first element of the row.
        event = row[0]
        return event

204 

205 

@attrs.define(init=False)
class InletEventsAccessors(Mapping[str, LazyDatasetEventSelectSequence]):
    """Mapping giving lazy access to the dataset events of a task's inlets.

    Items can be looked up by position in the inlet list, by URI string, or
    by ``Dataset``.

    :meta private:
    """

    _inlets: list[Any]
    _datasets: dict[str, Dataset]
    _session: Session

    def __init__(self, inlets: list, *, session: Session) -> None:
        self._inlets = inlets
        # Index only the inlets that are Dataset instances, keyed by URI.
        datasets_by_uri: dict[str, Dataset] = {}
        for inlet in inlets:
            if isinstance(inlet, Dataset):
                datasets_by_uri[inlet.uri] = inlet
        self._datasets = datasets_by_uri
        self._session = session

    def __iter__(self) -> Iterator[str]:
        # NOTE(review): this iterates the raw inlet list, not the URI keys.
        return iter(self._inlets)

    def __len__(self) -> int:
        return len(self._inlets)

    def __getitem__(self, key: int | str | Dataset) -> LazyDatasetEventSelectSequence:
        if not isinstance(key, int):
            dataset = self._datasets[coerce_to_uri(key)]
        else:
            # Positional access is supported as a convenience for trivial cases.
            dataset = self._inlets[key]
            if not isinstance(dataset, Dataset):
                raise IndexError(key)
        events_query = (
            select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == dataset.uri)
        )
        return LazyDatasetEventSelectSequence.from_select(
            events_query,
            order_by=[DatasetEvent.timestamp],
            session=self._session,
        )

240 

241 

class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
    """Warning category emitted when a task uses a deprecated context variable."""

244 

245 

def _create_deprecation_warning(key: str, replacements: list[str]) -> RemovedInAirflow3Warning:
    """Build the deprecation warning for context key ``key``.

    :param key: The deprecated context key that was accessed.
    :param replacements: Names to suggest instead. With none, the message
        carries no suggestion; with several, all but the last are
        comma-separated and the last is joined with "or".
    :return: An ``AirflowContextDeprecationWarning`` instance (not yet emitted).
    """
    message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
    if replacements:
        final = repr(replacements[-1])
        earlier = ", ".join(map(repr, replacements[:-1]))
        suggestion = f"{earlier} or {final}" if earlier else final
        message = f"{message} Please use {suggestion} instead."
    return AirflowContextDeprecationWarning(message)

256 

257 

class Context(MutableMapping[str, Any]):
    """Jinja2 template context used when rendering a task.

    Behaves like a dict, but reading a deprecated key through ``[]`` emits a
    deprecation warning at that moment. Assigning or deleting a key removes
    it from this instance's set of deprecated keys.
    """

    # Deprecated key -> names suggested in its warning. An empty list means
    # the warning carries no suggestion.
    _DEPRECATION_REPLACEMENTS: dict[str, list[str]] = {
        "execution_date": ["data_interval_start", "logical_date"],
        "next_ds": ["{{ data_interval_end | ds }}"],
        "next_ds_nodash": ["{{ data_interval_end | ds_nodash }}"],
        "next_execution_date": ["data_interval_end"],
        "prev_ds": [],
        "prev_ds_nodash": [],
        "prev_execution_date": [],
        "prev_execution_date_success": ["prev_data_interval_start_success"],
        "tomorrow_ds": [],
        "tomorrow_ds_nodash": [],
        "yesterday_ds": [],
        "yesterday_ds_nodash": [],
    }

    def __init__(self, context: MutableMapping[str, Any] | None = None, **kwargs: Any) -> None:
        # A falsy ``context`` (None or empty) is replaced by a fresh dict; a
        # non-empty mapping is wrapped as-is, not copied.
        self._context: MutableMapping[str, Any] = context if context else {}
        if kwargs:
            self._context.update(kwargs)
        # Per-instance copy so that set/delete only affect this context.
        self._deprecation_replacements = dict(self._DEPRECATION_REPLACEMENTS)

    def __repr__(self) -> str:
        return repr(self._context)

    def __reduce_ex__(self, protocol: SupportsIndex) -> tuple[Any, ...]:
        """Pickle the context as a plain dict.

        Values are deliberately read through ``self[key]`` rather than
        ``items()`` so that deprecation warnings are triggered.
        """
        pairs = []
        for key in self._context:
            pairs.append((key, self[key]))
        return dict, (pairs,)

    def __copy__(self) -> Context:
        duplicate = type(self)(copy.copy(self._context))
        duplicate._deprecation_replacements = self._deprecation_replacements.copy()
        return duplicate

    def __getitem__(self, key: str) -> Any:
        # The warning is issued before the lookup, so it fires even when the
        # deprecated key turns out to be absent.
        if key in self._deprecation_replacements:
            warnings.warn(
                _create_deprecation_warning(key, self._deprecation_replacements[key]),
                stacklevel=2,
            )
        try:
            return self._context[key]
        except KeyError:
            pass
        raise KeyError(key)

    def __setitem__(self, key: str, value: Any) -> None:
        self._deprecation_replacements.pop(key, None)
        self._context[key] = value

    def __delitem__(self, key: str) -> None:
        self._deprecation_replacements.pop(key, None)
        del self._context[key]

    def __contains__(self, key: object) -> bool:
        return key in self._context

    def __iter__(self) -> Iterator[str]:
        return iter(self._context)

    def __len__(self) -> int:
        return len(self._context)

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, Context):
            return self._context == other._context
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        if isinstance(other, Context):
            return self._context != other._context
        return NotImplemented

    def keys(self) -> KeysView[str]:
        return self._context.keys()

    def items(self):
        # View over the wrapped mapping: iteration bypasses ``__getitem__``.
        return ItemsView(self._context)

    def values(self):
        # View over the wrapped mapping: iteration bypasses ``__getitem__``.
        return ValuesView(self._context)

348 

349 

def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
    """Update ``context`` in place with the given entries.

    Accepts the same positional and keyword arguments as ``dict.update()``
    and forwards them unchanged to ``context.update()``.

    This lives in a free function because the ``Context`` type is "faked" as
    a ``TypedDict`` in ``context.pyi``, which cannot have custom functions.

    :meta private:
    """
    context.update(*args, **kwargs)

363 

364 

def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """Refresh ``context`` in place once a mapped task has been unmapped.

    ``get_template_context()`` is called before unmapping, so the context
    still describes the mapped task. This swaps in the unmapped ``task`` (both
    under ``"task"`` and on the task instance stored under ``"ti"``) and
    recomputes ``"params"`` for it.

    :meta private:
    """
    from airflow.models.param import process_params

    context["task"] = task
    context["ti"].task = task

    dag = context["dag"]
    dag_run = context["dag_run"]
    context["params"] = process_params(dag, task, dag_run, suppress_exception=False)

378 

379 

def context_copy_partial(source: Context, keys: Container[str]) -> Context:
    """Build a new context holding only the entries of ``source`` named in ``keys``.

    The new context also gets a copy of ``source``'s current deprecation table.

    This lives in a free function because the ``Context`` type is "faked" as
    a ``TypedDict`` in ``context.pyi``, which cannot have custom functions.

    :meta private:
    """
    # Read the wrapped mapping directly so copying never triggers
    # deprecation warnings.
    selected: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in keys:
            selected[name] = value
    partial_context = Context(selected)
    partial_context._deprecation_replacements = source._deprecation_replacements.copy()
    return partial_context

392 

393 

def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
    """Return a plain mapping whose deprecated entries are lazy object proxies.

    Wrapping the deprecated values postpones the deprecation warning until a
    value is actually used, rather than when it is merely read out of the
    context. That matters when the mapping is unpacked into a callable with
    ``**kwargs``, which would otherwise touch every entry eagerly.

    This lives in a free function because the ``Context`` type is "faked" as
    a ``TypedDict`` in ``context.pyi``, which cannot have custom functions.

    :meta private:
    """
    if not isinstance(source, Context):
        # A plain dict is sometimes passed instead (usually in tests, or from
        # users' custom operators). Be lenient and hand it back untouched so
        # nothing breaks for users.
        return source

    def _warn_then_return(name: str, value: Any) -> Any:
        # Runs when the proxy is first resolved: emit the warning, then
        # hand over the real value.
        replacements = source._deprecation_replacements[name]
        warnings.warn(_create_deprecation_warning(name, replacements), stacklevel=2)
        return value

    lazy_entries: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in source._deprecation_replacements:
            resolver = functools.partial(_warn_then_return, name, value)
            lazy_entries[name] = lazy_object_proxy.Proxy(resolver)
        else:
            lazy_entries[name] = value
    return lazy_entries

426 

427 

def context_get_outlet_events(context: Context) -> OutletEventAccessors:
    """Return ``context["outlet_events"]``, or a new empty accessor mapping if the key is missing."""
    try:
        accessors = context["outlet_events"]
    except KeyError:
        accessors = OutletEventAccessors()
    return accessors