Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/context.py: 41%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Jinja2 template rendering context helper."""
20from __future__ import annotations
22import contextlib
23import copy
24import functools
25import warnings
26from typing import (
27 TYPE_CHECKING,
28 Any,
29 Container,
30 ItemsView,
31 Iterator,
32 KeysView,
33 Mapping,
34 MutableMapping,
35 SupportsIndex,
36 ValuesView,
37)
39import attrs
40import lazy_object_proxy
41from sqlalchemy import select
43from airflow.datasets import Dataset, coerce_to_uri
44from airflow.exceptions import RemovedInAirflow3Warning
45from airflow.models.dataset import DatasetEvent, DatasetModel
46from airflow.utils.db import LazySelectSequence
47from airflow.utils.types import NOTSET
49if TYPE_CHECKING:
50 from sqlalchemy.engine import Row
51 from sqlalchemy.orm import Session
52 from sqlalchemy.sql.expression import Select, TextClause
54 from airflow.models.baseoperator import BaseOperator
# NOTE: Please keep this in sync with the following:
# * Context in airflow/utils/context.pyi.
# * Table in docs/apache-airflow/templates-ref.rst
#
# Every key a task template context may contain; listed alphabetically so
# additions and removals are easy to spot when syncing the files above.
KNOWN_CONTEXT_KEYS: set[str] = {
    "conf",
    "conn",
    "dag",
    "dag_run",
    "data_interval_end",
    "data_interval_start",
    "ds",
    "ds_nodash",
    "exception",
    "execution_date",
    "expanded_ti_count",
    "inlet_events",
    "inlets",
    "logical_date",
    "macros",
    "map_index_template",
    "next_ds",
    "next_ds_nodash",
    "next_execution_date",
    "outlet_events",
    "outlets",
    "params",
    "prev_data_interval_end_success",
    "prev_data_interval_start_success",
    "prev_ds",
    "prev_ds_nodash",
    "prev_end_date_success",
    "prev_execution_date",
    "prev_execution_date_success",
    "prev_start_date_success",
    "reason",
    "run_id",
    "task",
    "task_instance",
    "task_instance_key_str",
    "templates_dict",
    "test_mode",
    "ti",
    "tomorrow_ds",
    "tomorrow_ds_nodash",
    "triggering_dataset_events",
    "try_number",
    "ts",
    "ts_nodash",
    "ts_nodash_with_tz",
    "var",
    "yesterday_ds",
    "yesterday_ds_nodash",
}
class VariableAccessor:
    """Wrapper to access Variable values in template.

    Every attribute access (``accessor.my_key``) fetches the Variable
    ``my_key`` afresh; nothing is cached on the accessor. Earlier versions
    stored the last fetched value in ``self.var``. That made a Variable
    literally named ``var`` unreachable by attribute access and leaked the
    last fetched value through ``repr()``.

    Dunder names (``__x__``) are never treated as Variable keys, so Python's
    own probes (``copy.deepcopy``, pickling, ``hasattr(obj, "__html__")``)
    get a plain ``AttributeError``. Use :meth:`get` for such keys.
    """

    def __init__(self, *, deserialize_json: bool) -> None:
        # Forwarded to Variable.get() for both attribute and .get() access.
        self._deserialize_json = deserialize_json

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal attribute lookup fails.
        if key.startswith("__") and key.endswith("__"):
            # Special-method probes must fail the standard way instead of
            # triggering a Variable lookup (which raises something else).
            raise AttributeError(key)
        from airflow.models.variable import Variable

        return Variable.get(key, deserialize_json=self._deserialize_json)

    def __repr__(self) -> str:
        # Deliberately independent of any fetched value: the accessor is
        # stateless, and Variable values may be sensitive.
        return "<VariableAccessor (dynamic access)>"

    def get(self, key, default: Any = NOTSET) -> Any:
        """Return Variable ``key``, or ``default`` when one is supplied.

        Without ``default``, the lookup behaves like attribute access.
        """
        from airflow.models.variable import Variable

        if default is NOTSET:
            return Variable.get(key, deserialize_json=self._deserialize_json)
        return Variable.get(key, default, deserialize_json=self._deserialize_json)
class ConnectionAccessor:
    """Wrapper to access Connection entries in template.

    Every attribute access (``accessor.my_conn``) resolves the connection
    ``my_conn`` afresh via ``Connection.get_connection_from_secrets``; nothing
    is cached on the accessor. Earlier versions stored the last result in
    ``self.var``. That made a connection with id ``var`` unreachable by
    attribute access and leaked the last connection through ``repr()``.

    Dunder names (``__x__``) are never treated as connection ids, so Python's
    own probes (``copy.deepcopy``, pickling, ``hasattr(obj, "__html__")``)
    get a plain ``AttributeError``. Use :meth:`get` for such ids.
    """

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal attribute lookup fails.
        if key.startswith("__") and key.endswith("__"):
            # Special-method probes must fail the standard way instead of
            # triggering a secrets-backend lookup.
            raise AttributeError(key)
        from airflow.models.connection import Connection

        return Connection.get_connection_from_secrets(key)

    def __repr__(self) -> str:
        # Stateless accessor; never expose connection details here.
        return "<ConnectionAccessor (dynamic access)>"

    def get(self, key: str, default_conn: Any = None) -> Any:
        """Return connection ``key``, or ``default_conn`` if it is not found.

        Only ``AirflowNotFoundException`` is translated into the default; any
        other error from the lookup propagates.
        """
        from airflow.exceptions import AirflowNotFoundException
        from airflow.models.connection import Connection

        try:
            return Connection.get_connection_from_secrets(key)
        except AirflowNotFoundException:
            return default_conn
@attrs.define
class OutletEventAccessor:
    """Wrapper to access an outlet dataset event in template.

    A small attrs record; ``__init__``, ``__repr__`` and ``__eq__`` are
    generated by ``attrs.define``.

    :meta private:
    """

    # Extra information for this outlet; OutletEventAccessors creates each
    # accessor with an empty dict here.
    extra: dict[str, Any]
class OutletEventAccessors(Mapping[str, OutletEventAccessor]):
    """Lazy mapping of outlet dataset event accessors.

    Keys are dataset URIs; a ``Dataset`` may also be used as a key and is
    normalized with ``coerce_to_uri``. Looking up a URI with no accessor yet
    creates and stores an empty one. Iteration and ``len()`` therefore only
    reflect URIs that have been looked up.

    NOTE: the ``in`` operator and ``get()`` inherited from ``Mapping`` go
    through ``__getitem__``. They therefore also create an entry, and
    membership tests succeed for any key ``coerce_to_uri`` accepts.

    :meta private:
    """

    def __init__(self) -> None:
        # URI -> accessor, populated lazily by __getitem__.
        self._dict: dict[str, OutletEventAccessor] = {}

    def __iter__(self) -> Iterator[str]:
        return iter(self._dict)

    def __len__(self) -> int:
        return len(self._dict)

    def __getitem__(self, key: str | Dataset) -> OutletEventAccessor:
        uri = coerce_to_uri(key)
        try:
            return self._dict[uri]
        except KeyError:
            # First access for this URI: materialize an empty accessor.
            accessor = self._dict[uri] = OutletEventAccessor({})
            return accessor
class LazyDatasetEventSelectSequence(LazySelectSequence[DatasetEvent]):
    """List-like interface to lazily access DatasetEvent rows.

    Only the two row hooks are customized here. Presumably the base
    ``LazySelectSequence`` decides when they are called (not visible in
    this file).

    :meta private:
    """

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        # Re-target the textual statement at the DatasetEvent ORM entity.
        event_select = select(DatasetEvent)
        return event_select.from_statement(stmt)

    @staticmethod
    def _process_row(row: Row) -> DatasetEvent:
        # Each result row carries the DatasetEvent in its first column.
        first_column = row[0]
        return first_column
@attrs.define(init=False)
class InletEventsAccessors(Mapping[str, LazyDatasetEventSelectSequence]):
    """Lazy mapping for inlet dataset events accessors.

    A key may be a dataset URI, a ``Dataset`` (normalized by
    ``coerce_to_uri``), or an ``int`` position in the inlet list. Each lookup
    returns a ``LazyDatasetEventSelectSequence``. That sequence is built from
    a select of ``DatasetEvent`` rows joined to the dataset whose URI
    matches, ordered by timestamp, using the stored session.

    NOTE(review): iteration and ``len()`` cover *all* inlets, including
    non-``Dataset`` ones. ``__iter__`` therefore yields inlet objects rather
    than URI strings. Confirm this is intended.

    :meta private:
    """

    _inlets: list[Any]
    _datasets: dict[str, Dataset]
    _session: Session

    def __init__(self, inlets: list, *, session: Session) -> None:
        self._session = session
        self._inlets = inlets
        # Index only the Dataset inlets, by URI, for key-based lookup.
        by_uri: dict[str, Dataset] = {}
        for candidate in inlets:
            if isinstance(candidate, Dataset):
                by_uri[candidate.uri] = candidate
        self._datasets = by_uri

    def __iter__(self) -> Iterator[str]:
        return iter(self._inlets)

    def __len__(self) -> int:
        return len(self._inlets)

    def __getitem__(self, key: int | str | Dataset) -> LazyDatasetEventSelectSequence:
        if not isinstance(key, int):
            # URI or Dataset key; unknown datasets raise KeyError.
            dataset = self._datasets[coerce_to_uri(key)]
        else:
            # Positional access is easier for trivial cases; a position that
            # holds a non-Dataset inlet is reported as IndexError.
            dataset = self._inlets[key]
            if not isinstance(dataset, Dataset):
                raise IndexError(key)
        events_for_dataset = (
            select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == dataset.uri)
        )
        return LazyDatasetEventSelectSequence.from_select(
            events_for_dataset,
            order_by=[DatasetEvent.timestamp],
            session=self._session,
        )
class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
    """Warning category for reading a deprecated template context variable from a task."""
def _create_deprecation_warning(key: str, replacements: list[str]) -> RemovedInAirflow3Warning:
    """Build the warning emitted when deprecated context key ``key`` is read.

    :param key: the deprecated context key.
    :param replacements: suggested alternatives. When empty, no suggestion is
        appended. Otherwise all but the last are comma-separated and the last
        is joined with "or".
    :return: an ``AirflowContextDeprecationWarning`` carrying the message.
    """
    message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
    if replacements:
        *leading, last = replacements
        if leading:
            suggestion = f"{', '.join(repr(r) for r in leading)} or {last!r}"
        else:
            suggestion = f"{last!r}"
        message += f" Please use {suggestion} instead."
    return AirflowContextDeprecationWarning(message)
258class Context(MutableMapping[str, Any]):
259 """Jinja2 template context for task rendering.
261 This is a mapping (dict-like) class that can lazily emit warnings when
262 (and only when) deprecated context keys are accessed.
263 """
265 _DEPRECATION_REPLACEMENTS: dict[str, list[str]] = {
266 "execution_date": ["data_interval_start", "logical_date"],
267 "next_ds": ["{{ data_interval_end | ds }}"],
268 "next_ds_nodash": ["{{ data_interval_end | ds_nodash }}"],
269 "next_execution_date": ["data_interval_end"],
270 "prev_ds": [],
271 "prev_ds_nodash": [],
272 "prev_execution_date": [],
273 "prev_execution_date_success": ["prev_data_interval_start_success"],
274 "tomorrow_ds": [],
275 "tomorrow_ds_nodash": [],
276 "yesterday_ds": [],
277 "yesterday_ds_nodash": [],
278 }
280 def __init__(self, context: MutableMapping[str, Any] | None = None, **kwargs: Any) -> None:
281 self._context: MutableMapping[str, Any] = context or {}
282 if kwargs:
283 self._context.update(kwargs)
284 self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()
286 def __repr__(self) -> str:
287 return repr(self._context)
289 def __reduce_ex__(self, protocol: SupportsIndex) -> tuple[Any, ...]:
290 """Pickle the context as a dict.
292 We are intentionally going through ``__getitem__`` in this function,
293 instead of using ``items()``, to trigger deprecation warnings.
294 """
295 items = [(key, self[key]) for key in self._context]
296 return dict, (items,)
298 def __copy__(self) -> Context:
299 new = type(self)(copy.copy(self._context))
300 new._deprecation_replacements = self._deprecation_replacements.copy()
301 return new
303 def __getitem__(self, key: str) -> Any:
304 with contextlib.suppress(KeyError):
305 warnings.warn(
306 _create_deprecation_warning(key, self._deprecation_replacements[key]),
307 stacklevel=2,
308 )
309 with contextlib.suppress(KeyError):
310 return self._context[key]
311 raise KeyError(key)
313 def __setitem__(self, key: str, value: Any) -> None:
314 self._deprecation_replacements.pop(key, None)
315 self._context[key] = value
317 def __delitem__(self, key: str) -> None:
318 self._deprecation_replacements.pop(key, None)
319 del self._context[key]
321 def __contains__(self, key: object) -> bool:
322 return key in self._context
324 def __iter__(self) -> Iterator[str]:
325 return iter(self._context)
327 def __len__(self) -> int:
328 return len(self._context)
330 def __eq__(self, other: Any) -> bool:
331 if not isinstance(other, Context):
332 return NotImplemented
333 return self._context == other._context
335 def __ne__(self, other: Any) -> bool:
336 if not isinstance(other, Context):
337 return NotImplemented
338 return self._context != other._context
340 def keys(self) -> KeysView[str]:
341 return self._context.keys()
343 def items(self):
344 return ItemsView(self._context)
346 def values(self):
347 return ValuesView(self._context)
def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
    """Merge parameters into an existing context, in place.

    Accepts the same arguments as ``dict.update()`` (a mapping or an iterable
    of key/value pairs, plus keyword arguments) and forwards them to
    ``context.update``.

    A free function is used because the ``Context`` type is presented as a
    ``TypedDict`` in ``context.pyi``, and a ``TypedDict`` cannot declare
    custom methods.

    :meta private:
    """
    update = context.update
    update(*args, **kwargs)
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """Update context after task unmapping.

    ``get_template_context()`` runs before unmapping, so the context still
    describes the mapped task. This rewrites it in place to reflect the
    unmapped ``task``. It replaces the ``task`` entry and the task instance's
    ``task`` attribute, and recomputes ``params``.

    :meta private:
    """
    from airflow.models.param import process_params

    # Same assignment order as a chained ``a = b = task``: context key first.
    context["task"] = task
    context["ti"].task = task

    # Re-resolve params against the unmapped task; with
    # suppress_exception=False, errors from process_params propagate.
    dag = context["dag"]
    dag_run = context["dag_run"]
    context["params"] = process_params(dag, task, dag_run, suppress_exception=False)
def context_copy_partial(source: Context, keys: Container[str]) -> Context:
    """Create a context by copying items under selected keys in ``source``.

    Values are read straight from the wrapped mapping, so no deprecation
    warnings fire during the copy. The deprecation state is carried over so
    the new context warns on the same keys.

    This is implemented as a free function because the ``Context`` type is
    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
    functions.

    :meta private:
    """
    selected: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in keys:
            selected[name] = value

    partial_copy = Context(selected)
    # The whole replacement table is copied; entries for unselected keys
    # are inert because those keys are absent from the new context.
    partial_copy._deprecation_replacements = source._deprecation_replacements.copy()
    return partial_copy
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
    """Create a mapping that wraps deprecated entries in a lazy object proxy.

    The deprecation warning is deferred until the wrapped value is actually
    used, rather than emitted when the entry is read from the context. The
    result is meant for passing into a callable with ``**kwargs``, which
    would otherwise unpack the mapping too eagerly.

    This is implemented as a free function because the ``Context`` type is
    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
    functions.

    :meta private:
    """
    if not isinstance(source, Context):
        # Sometimes we are passed a plain dict (usually in tests, or in users'
        # custom operators); be lenient about what we accept so we don't
        # break anything for users.
        return source

    def _warn_then_return(key: str, value: Any) -> Any:
        # Runs when the proxy is first resolved; the replacement list is
        # looked up at that moment, not when the mapping is built.
        warnings.warn(
            _create_deprecation_warning(key, source._deprecation_replacements[key]),
            stacklevel=2,
        )
        return value

    lazy: dict[str, Any] = {}
    for key, value in source._context.items():
        if key in source._deprecation_replacements:
            value = lazy_object_proxy.Proxy(functools.partial(_warn_then_return, key, value))
        lazy[key] = value
    return lazy
def context_get_outlet_events(context: Context) -> OutletEventAccessors:
    """Return the context's ``outlet_events`` entry.

    Falls back to a fresh, empty ``OutletEventAccessors`` when the key is
    absent. The fallback is not stored back into ``context``.
    """
    with contextlib.suppress(KeyError):
        return context["outlet_events"]
    return OutletEventAccessors()