1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Jinja2 template rendering context helper."""
19
20from __future__ import annotations
21
22import contextlib
23import copy
24import functools
25import warnings
26from typing import (
27 TYPE_CHECKING,
28 Any,
29 Container,
30 ItemsView,
31 Iterator,
32 KeysView,
33 Mapping,
34 MutableMapping,
35 SupportsIndex,
36 ValuesView,
37)
38
39import attrs
40import lazy_object_proxy
41from sqlalchemy import select
42
43from airflow.datasets import Dataset, coerce_to_uri
44from airflow.exceptions import RemovedInAirflow3Warning
45from airflow.models.dataset import DatasetEvent, DatasetModel
46from airflow.utils.db import LazySelectSequence
47from airflow.utils.types import NOTSET
48
49if TYPE_CHECKING:
50 from sqlalchemy.engine import Row
51 from sqlalchemy.orm import Session
52 from sqlalchemy.sql.expression import Select, TextClause
53
54 from airflow.models.baseoperator import BaseOperator
55
# NOTE: Please keep this in sync with the following:
# * Context in airflow/utils/context.pyi.
# * Table in docs/apache-airflow/templates-ref.rst
#
# Every key that may appear in a task's template context. Entries are kept in
# alphabetical order so that additions and reviews are easy to check.
KNOWN_CONTEXT_KEYS: set[str] = {
    "conf",
    "conn",
    "dag",
    "dag_run",
    "data_interval_end",
    "data_interval_start",
    "ds",
    "ds_nodash",
    "exception",
    "execution_date",
    "expanded_ti_count",
    "inlet_events",
    "inlets",
    "logical_date",
    "macros",
    "map_index_template",
    "next_ds",
    "next_ds_nodash",
    "next_execution_date",
    "outlet_events",
    "outlets",
    "params",
    "prev_data_interval_end_success",
    "prev_data_interval_start_success",
    "prev_ds",
    "prev_ds_nodash",
    "prev_end_date_success",
    "prev_execution_date",
    "prev_execution_date_success",
    "prev_start_date_success",
    "reason",
    "run_id",
    "task",
    "task_instance",
    "task_instance_key_str",
    "templates_dict",
    "test_mode",
    "ti",
    "tomorrow_ds",
    "tomorrow_ds_nodash",
    "triggering_dataset_events",
    "try_number",
    "ts",
    "ts_nodash",
    "ts_nodash_with_tz",
    "var",
    "yesterday_ds",
    "yesterday_ds_nodash",
}
109
110
class VariableAccessor:
    """Wrapper to access Variable values in template.

    Reading any non-dunder attribute fetches the Airflow Variable whose key is
    the attribute name via ``Variable.get``, forwarding ``deserialize_json``.
    """

    def __init__(self, *, deserialize_json: bool) -> None:
        self._deserialize_json = deserialize_json
        # Most recent value fetched through attribute access; rendered by __repr__.
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        # __getattr__ only runs when normal attribute lookup fails. That also
        # covers protocol probes such as ``__deepcopy__`` (copy module) and
        # ``__setstate__`` (pickle). Those must see a plain AttributeError so
        # that ``hasattr``/``getattr(obj, name, default)`` work and no Variable
        # lookup happens. While unpickling, ``__init__`` has not run and
        # ``_deserialize_json`` is unset, so reading it below would re-enter
        # __getattr__ and recurse forever.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {key!r}")
        from airflow.models.variable import Variable

        self.var = Variable.get(key, deserialize_json=self._deserialize_json)
        return self.var

    def __repr__(self) -> str:
        return str(self.var)

    def get(self, key: str, default: Any = NOTSET) -> Any:
        """Fetch Variable ``key``, passing ``default`` to ``Variable.get`` if given.

        When ``default`` is omitted, ``Variable.get`` is called without a
        default, so its own missing-key behavior applies.
        """
        from airflow.models.variable import Variable

        if default is NOTSET:
            return Variable.get(key, deserialize_json=self._deserialize_json)
        return Variable.get(key, default, deserialize_json=self._deserialize_json)
133
134
class ConnectionAccessor:
    """Wrapper to access Connection entries in template.

    Reading any non-dunder attribute looks up the connection with that ID via
    ``Connection.get_connection_from_secrets``. :meth:`get` does the same but
    returns a fallback value when the connection is not found.
    """

    def __init__(self) -> None:
        # Most recent connection fetched through attribute access; shown by __repr__.
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        # __getattr__ only runs when normal attribute lookup fails. That also
        # covers protocol probes such as ``__deepcopy__`` (copy module) and
        # ``__setstate__`` (pickle). Those must see a plain AttributeError
        # instead of triggering a secrets-backend lookup. Judging by ``get``
        # below, a missing connection is signalled with
        # AirflowNotFoundException rather than AttributeError, which would
        # break ``hasattr`` and copying.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {key!r}")
        from airflow.models.connection import Connection

        self.var = Connection.get_connection_from_secrets(key)
        return self.var

    def __repr__(self) -> str:
        return str(self.var)

    def get(self, key: str, default_conn: Any = None) -> Any:
        """Return connection ``key``, or ``default_conn`` if it is not found."""
        from airflow.exceptions import AirflowNotFoundException
        from airflow.models.connection import Connection

        try:
            return Connection.get_connection_from_secrets(key)
        except AirflowNotFoundException:
            return default_conn
158
159
@attrs.define
class OutletEventAccessor:
    """Wrapper to access an outlet dataset event in template.

    :meta private:
    """

    # Extra data attached to the outlet event. Accessors created lazily by
    # ``OutletEventAccessors`` start out with an empty dict here.
    extra: dict[str, Any] = attrs.field()
168
169
class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
    """Lazy mapping of outlet dataset event accessors.

    Keys are normalized with ``coerce_to_uri``, so a key may be given as a URI
    string or a :class:`Dataset`. The first lookup of a key creates and stores
    an :class:`OutletEventAccessor` with an empty ``extra`` dict. Later lookups
    of the same key return that same accessor.

    NOTE(review): ``in`` and ``get`` are inherited from ``Mapping`` and go
    through ``__getitem__``. They therefore also create an entry, and they
    report presence, for any key ``coerce_to_uri`` accepts.

    :meta private:
    """

    def __init__(self) -> None:
        # Normalized URI -> accessor, populated on demand by __getitem__.
        self._dict: dict[str, OutletEventAccessor] = {}

    def __iter__(self) -> Iterator[str]:
        # Only accessors that have already been created are visited.
        return iter(self._dict)

    def __len__(self) -> int:
        return len(self._dict)

    def __getitem__(self, key: str | Dataset) -> OutletEventAccessor:
        uri = coerce_to_uri(key)
        try:
            return self._dict[uri]
        except KeyError:
            created = self._dict[uri] = OutletEventAccessor({})
            return created
189
190
class LazyDatasetEventSelectSequence(LazySelectSequence[DatasetEvent]):
    """List-like interface to lazily access DatasetEvent rows.

    Only the statement-rebuilding and row-processing hooks are specialised
    here. The sequence behavior itself is inherited from ``LazySelectSequence``.

    :meta private:
    """

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        # Bind the textual statement to the DatasetEvent entity so that result
        # rows are loaded as ORM objects.
        entity_select = select(DatasetEvent)
        return entity_select.from_statement(stmt)

    @staticmethod
    def _process_row(row: Row) -> DatasetEvent:
        # Rows come from ``select(DatasetEvent)``, so the entity is the first column.
        event = row[0]
        return event
204
205
@attrs.define(init=False)
class InletEventsAccessors(Mapping[str, LazyDatasetEventSelectSequence]):
    """Lazy mapping for inlet dataset events accessors.

    Each lookup returns a :class:`LazyDatasetEventSelectSequence` over the
    ``DatasetEvent`` rows of one inlet dataset, ordered by timestamp. A key may
    be a dataset URI string, a :class:`Dataset`, or an integer position in the
    inlets list.

    :meta private:
    """

    _inlets: list[Any]
    _datasets: dict[str, Dataset]
    _session: Session

    def __init__(self, inlets: list, *, session: Session) -> None:
        self._inlets = inlets
        # Only Dataset inlets are addressable by URI; other inlet objects are skipped.
        by_uri: dict[str, Dataset] = {}
        for inlet in inlets:
            if isinstance(inlet, Dataset):
                by_uri[inlet.uri] = inlet
        self._datasets = by_uri
        self._session = session

    def __iter__(self) -> Iterator[str]:
        # NOTE(review): this yields the raw inlet objects, which may include
        # non-Dataset values, rather than URI strings as the Mapping[str, ...]
        # declaration suggests.
        return iter(self._inlets)

    def __len__(self) -> int:
        return len(self._inlets)

    def __getitem__(self, key: int | str | Dataset) -> LazyDatasetEventSelectSequence:
        if not isinstance(key, int):
            # URI / Dataset lookup; unknown URIs raise KeyError.
            dataset = self._datasets[coerce_to_uri(key)]
        else:
            # Positional access is a convenience for trivial cases. It is only
            # valid when the inlet at that position is a Dataset.
            dataset = self._inlets[key]
            if not isinstance(dataset, Dataset):
                raise IndexError(key)
        events_query = (
            select(DatasetEvent)
            .join(DatasetEvent.dataset)
            .where(DatasetModel.uri == dataset.uri)
        )
        return LazyDatasetEventSelectSequence.from_select(
            events_query,
            order_by=[DatasetEvent.timestamp],
            session=self._session,
        )
240
241
class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
    """Warn for usage of deprecated context variables in a task.

    Instances are built by ``_create_deprecation_warning`` and emitted when a
    deprecated key is read from a ``Context``, or when a proxied deprecated
    value produced by ``lazy_mapping_from_context`` is first used.
    """
244
245
def _create_deprecation_warning(key: str, replacements: list[str]) -> RemovedInAirflow3Warning:
    """Build (but do not emit) the warning for reading deprecated context ``key``.

    :param key: The deprecated context key that was accessed.
    :param replacements: Suggested alternatives. They are rendered with
        ``repr`` as ``'a'``, ``'a' or 'b'``, ``'a', 'b' or 'c'`` and so on. When
        the list is empty, no suggestion is added.
    :return: An ``AirflowContextDeprecationWarning`` carrying the message.
    """
    message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
    if replacements:
        *leading, final = replacements
        if leading:
            message += f" Please use {', '.join(repr(r) for r in leading)} or {final!r} instead."
        else:
            message += f" Please use {final!r} instead."
    return AirflowContextDeprecationWarning(message)
256
257
class Context(MutableMapping[str, Any]):
    """Jinja2 template context for task rendering.

    This is a mapping (dict-like) class that can lazily emit warnings when
    (and only when) deprecated context keys are accessed.

    Deprecated keys are tracked per instance. Once a deprecated key is set or
    deleted through this mapping, reading it no longer warns.
    """

    # Deprecated key -> suggested replacements named in the warning message.
    # An empty list means the message suggests no replacement.
    _DEPRECATION_REPLACEMENTS: dict[str, list[str]] = {
        "execution_date": ["data_interval_start", "logical_date"],
        "next_ds": ["{{ data_interval_end | ds }}"],
        "next_ds_nodash": ["{{ data_interval_end | ds_nodash }}"],
        "next_execution_date": ["data_interval_end"],
        "prev_ds": [],
        "prev_ds_nodash": [],
        "prev_execution_date": [],
        "prev_execution_date_success": ["prev_data_interval_start_success"],
        "tomorrow_ds": [],
        "tomorrow_ds_nodash": [],
        "yesterday_ds": [],
        "yesterday_ds_nodash": [],
    }

    def __init__(self, context: MutableMapping[str, Any] | None = None, **kwargs: Any) -> None:
        # NOTE(review): a non-empty ``context`` is stored as-is (not copied), so
        # ``kwargs`` and later writes also land in the caller's mapping. A
        # ``None`` or empty ``context`` is replaced by a fresh dict.
        self._context: MutableMapping[str, Any] = context or {}
        if kwargs:
            self._context.update(kwargs)
        # Per-instance copy, so __setitem__/__delitem__ can drop entries without
        # touching the class-level defaults.
        self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()

    def __repr__(self) -> str:
        return repr(self._context)

    def __reduce_ex__(self, protocol: SupportsIndex) -> tuple[Any, ...]:
        """Pickle the context as a dict.

        We are intentionally going through ``__getitem__`` in this function,
        instead of using ``items()``, to trigger deprecation warnings.
        """
        items = [(key, self[key]) for key in self._context]
        return dict, (items,)

    def __copy__(self) -> Context:
        # Shallow copy of the underlying mapping, carrying over which
        # deprecated keys are still untouched.
        new = type(self)(copy.copy(self._context))
        new._deprecation_replacements = self._deprecation_replacements.copy()
        return new

    def __getitem__(self, key: str) -> Any:
        # Resolve the value first. An absent key must raise KeyError without
        # emitting a deprecation warning. Examples: ``Mapping.get`` probes, or
        # a context from ``context_copy_partial`` that omits the key but still
        # carries its deprecation entry.
        try:
            value = self._context[key]
        except KeyError:
            raise KeyError(key) from None
        if key in self._deprecation_replacements:
            # stacklevel=2 attributes the warning to the caller of __getitem__.
            warnings.warn(
                _create_deprecation_warning(key, self._deprecation_replacements[key]),
                stacklevel=2,
            )
        return value

    def __setitem__(self, key: str, value: Any) -> None:
        # A value assigned explicitly no longer triggers a deprecation warning.
        self._deprecation_replacements.pop(key, None)
        self._context[key] = value

    def __delitem__(self, key: str) -> None:
        self._deprecation_replacements.pop(key, None)
        del self._context[key]

    def __contains__(self, key: object) -> bool:
        # Checks the underlying mapping directly, so membership never warns.
        return key in self._context

    def __iter__(self) -> Iterator[str]:
        return iter(self._context)

    def __len__(self) -> int:
        return len(self._context)

    def __eq__(self, other: Any) -> bool:
        # Only comparable with other Context instances; anything else defers.
        if not isinstance(other, Context):
            return NotImplemented
        return self._context == other._context

    def __ne__(self, other: Any) -> bool:
        if not isinstance(other, Context):
            return NotImplemented
        return self._context != other._context

    def keys(self) -> KeysView[str]:
        return self._context.keys()

    def items(self):
        # Views over the underlying mapping bypass __getitem__, so iterating
        # them does not emit deprecation warnings.
        return ItemsView(self._context)

    def values(self):
        return ValuesView(self._context)
348
349
def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
    """Merge parameters into an existing context.

    Accepts the same arguments as ``dict.update()`` and, like it, updates
    ``context`` in place.

    This is implemented as a free function because the ``Context`` type is
    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
    functions.

    :meta private:
    """
    # ``Context`` inherits ``update`` from ``MutableMapping``, which stores each
    # item through ``__setitem__``. As a result, merged deprecated keys stop
    # warning.
    apply_update = context.update
    apply_update(*args, **kwargs)
363
364
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """Update context after task unmapping.

    ``get_template_context()`` runs before the task is unmapped, so the context
    still describes the mapped task. This rewrites the affected entries in
    place so that templates see the unmapped task instead. It sets ``task``
    (and re-points the task instance's ``task``) and recomputes ``params``.

    :meta private:
    """
    from airflow.models.param import process_params

    # Same order as the equivalent chained assignment: the context entry is
    # written first, then the task instance is re-pointed.
    context["task"] = task
    context["ti"].task = task

    # Recompute params for the unmapped task. ``suppress_exception=False``
    # presumably lets param processing errors propagate (confirm in process_params).
    dag = context["dag"]
    dag_run = context["dag_run"]
    context["params"] = process_params(dag, task, dag_run, suppress_exception=False)
378
379
def context_copy_partial(source: Context, keys: Container[str]) -> Context:
    """Create a context by copying items under selected keys in ``source``.

    Values are shared with ``source``, not deep-copied. The new context also
    inherits ``source``'s record of which deprecated keys are still untouched.

    This is implemented as a free function because the ``Context`` type is
    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
    functions.

    :meta private:
    """
    selected = {}
    for name, value in source._context.items():
        if name in keys:
            selected[name] = value
    partial = Context(selected)
    # Replace the defaults installed by Context.__init__ with a separate copy of
    # the source's bookkeeping, so the two contexts do not share one dict.
    partial._deprecation_replacements = dict(source._deprecation_replacements)
    return partial
392
393
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
    """Create a mapping that wraps deprecated entries in a lazy object proxy.

    The deprecation warning for such an entry is postponed until the value is
    actually used, instead of being emitted when the entry is read from the
    context. This makes the result suitable for passing to a callable via
    ``**kwargs``, because unpacking would otherwise touch every entry eagerly.

    This is implemented as a free function because the ``Context`` type is
    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
    functions.

    :meta private:
    """
    if not isinstance(source, Context):
        # Plain dicts are sometimes passed in (usually from tests or users'
        # custom operators). Be lenient and return them unchanged so nothing
        # breaks for users.
        return source

    def _warn_and_unwrap(name: str, value: Any) -> Any:
        # Proxy factory: runs when the proxied value is first used. The
        # replacement list is looked up at that moment, not when the mapping
        # is built.
        warnings.warn(
            _create_deprecation_warning(name, source._deprecation_replacements[name]),
            stacklevel=2,
        )
        return value

    lazy: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in source._deprecation_replacements:
            value = lazy_object_proxy.Proxy(functools.partial(_warn_and_unwrap, name, value))
        lazy[name] = value
    return lazy
426
427
def context_get_outlet_events(context: Context) -> OutletEventAccessors:
    """Return the ``outlet_events`` entry of ``context``.

    If the key is absent, a new, empty :class:`OutletEventAccessors` is
    returned instead.

    NOTE(review): the fallback accessor is not stored back into ``context``,
    so events recorded on it are only kept if the caller holds on to it.
    """
    try:
        accessors = context["outlet_events"]
    except KeyError:
        accessors = OutletEventAccessors()
    return accessors