Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/context.py: 36%
125 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Jinja2 template rendering context helper."""
19from __future__ import annotations
21import contextlib
22import copy
23import functools
24import warnings
25from typing import (
26 TYPE_CHECKING,
27 Any,
28 Container,
29 ItemsView,
30 Iterator,
31 KeysView,
32 Mapping,
33 MutableMapping,
34 SupportsIndex,
35 ValuesView,
36)
38import lazy_object_proxy
40from airflow.exceptions import RemovedInAirflow3Warning
41from airflow.utils.types import NOTSET
43if TYPE_CHECKING:
44 from airflow.models.baseoperator import BaseOperator
# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi.
# Names of the template-context keys Airflow knows about, listed alphabetically;
# this mirrors the fields declared on ``Context`` in context.pyi.
KNOWN_CONTEXT_KEYS: set[str] = {
    "conf",
    "conn",
    "dag",
    "dag_run",
    "data_interval_end",
    "data_interval_start",
    "ds",
    "ds_nodash",
    "exception",
    "execution_date",
    "expanded_ti_count",
    "inlets",
    "logical_date",
    "macros",
    "next_ds",
    "next_ds_nodash",
    "next_execution_date",
    "outlets",
    "params",
    "prev_data_interval_end_success",
    "prev_data_interval_start_success",
    "prev_ds",
    "prev_ds_nodash",
    "prev_execution_date",
    "prev_execution_date_success",
    "prev_start_date_success",
    "run_id",
    "task",
    "task_instance",
    "task_instance_key_str",
    "templates_dict",
    "test_mode",
    "ti",
    "tomorrow_ds",
    "tomorrow_ds_nodash",
    "triggering_dataset_events",
    "try_number",
    "ts",
    "ts_nodash",
    "ts_nodash_with_tz",
    "var",
    "yesterday_ds",
    "yesterday_ds_nodash",
}
class VariableAccessor:
    """Wrapper to access Variable values in template.

    Attribute access (e.g. ``var.value.my_key`` in a template) fetches the
    Airflow Variable named ``my_key`` via ``Variable.get``. :meth:`get` does
    the same but lets the caller supply a default.

    Names that resolve as real attributes of this object (``var``, ``get``,
    ``_deserialize_json``) are never looked up as Variables, and dunder names
    are rejected with ``AttributeError`` (see :meth:`__getattr__`).
    """

    def __init__(self, *, deserialize_json: bool) -> None:
        # Forwarded to ``Variable.get`` on every lookup.
        self._deserialize_json = deserialize_json
        # Most recently fetched value via attribute access; shown by __repr__.
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        """Fetch the Variable named ``key`` and remember it in ``self.var``.

        :raises AttributeError: if ``key`` is a dunder name. Python and
            libraries probe objects for optional protocol hooks
            (``copy``/``pickle`` look for ``__setstate__``/``__deepcopy__``,
            markupsafe looks for ``__html__``, ...). Those probes only
            tolerate ``AttributeError``. Forwarding them to ``Variable.get``
            would hit the backend and, for a missing key, raise an exception
            that is not an ``AttributeError``.
        """
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        from airflow.models.variable import Variable

        self.var = Variable.get(key, deserialize_json=self._deserialize_json)
        return self.var

    def __repr__(self) -> str:
        # Renders the last value fetched through attribute access (or "None").
        return str(self.var)

    def get(self, key, default: Any = NOTSET) -> Any:
        """Return the Variable ``key``, falling back to ``default`` if one is given.

        When ``default`` is omitted it is not forwarded to ``Variable.get``,
        so that call's own missing-key behaviour applies. Unlike attribute
        access, this does not update ``self.var``.
        """
        from airflow.models.variable import Variable

        if default is NOTSET:
            return Variable.get(key, deserialize_json=self._deserialize_json)
        return Variable.get(key, default, deserialize_json=self._deserialize_json)
class ConnectionAccessor:
    """Wrapper to access Connection entries in template.

    Attribute access (e.g. ``conn.my_conn_id`` in a template) resolves the
    connection through ``Connection.get_connection_from_secrets``. :meth:`get`
    returns a fallback instead of raising when the connection is not found.

    Names that resolve as real attributes of this object (``var``, ``get``)
    are never looked up as connections, and dunder names are rejected with
    ``AttributeError`` (see :meth:`__getattr__`).
    """

    def __init__(self) -> None:
        # Most recently fetched connection via attribute access; shown by __repr__.
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        """Fetch the connection whose id is ``key`` and remember it in ``self.var``.

        :raises AttributeError: if ``key`` is a dunder name. Protocol probes
            (``copy``/``pickle`` looking for ``__setstate__``/``__deepcopy__``,
            markupsafe looking for ``__html__``, ...) only tolerate
            ``AttributeError``. A failed connection lookup raises
            ``AirflowNotFoundException`` instead, which would break ``hasattr``
            and ``getattr(obj, name, default)``.
        """
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        from airflow.models.connection import Connection

        self.var = Connection.get_connection_from_secrets(key)
        return self.var

    def __repr__(self) -> str:
        # Renders the last connection fetched through attribute access (or "None").
        return str(self.var)

    def get(self, key: str, default_conn: Any = None) -> Any:
        """Return connection ``key``, or ``default_conn`` if it is not found.

        Only ``AirflowNotFoundException`` is converted into the fallback; any
        other error propagates. Unlike attribute access, this does not update
        ``self.var``.
        """
        from airflow.exceptions import AirflowNotFoundException
        from airflow.models.connection import Connection

        try:
            return Connection.get_connection_from_secrets(key)
        except AirflowNotFoundException:
            return default_conn
class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
    """Warning category for reading a deprecated context variable in a task.

    Instances are built by ``_create_deprecation_warning`` and passed to
    ``warnings.warn`` when a deprecated key is read from a ``Context``.
    """
def _create_deprecation_warning(key: str, replacements: list[str]) -> RemovedInAirflow3Warning:
    """Build the warning for an access to the deprecated context ``key``.

    :param key: The deprecated context key that was accessed.
    :param replacements: Suggested alternatives, in order. When empty, the
        message carries no "Please use ..." hint. When there are several,
        all but the last are comma-separated and the last follows "or".
    :return: An ``AirflowContextDeprecationWarning`` carrying the message.
    """
    text = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
    if replacements:
        final = replacements[-1]
        leading = ", ".join(map(repr, replacements[:-1]))
        if leading:
            text += f" Please use {leading} or {final!r} instead."
        else:
            text += f" Please use {final!r} instead."
    return AirflowContextDeprecationWarning(text)
class Context(MutableMapping[str, Any]):
    """Jinja2 template context for task rendering.

    This is a mapping (dict-like) class that can lazily emit warnings when
    (and only when) deprecated context keys are accessed.

    Reading a deprecated key with ``ctx[key]`` (and therefore with
    ``ctx.get(key)``, which goes through ``__getitem__``) emits an
    ``AirflowContextDeprecationWarning``, but only if the key is actually
    present. Assigning or deleting a deprecated key stops it from being
    treated as deprecated on this instance. ``in``, iteration, ``len``,
    ``keys()``, ``items()`` and ``values()`` never warn.
    """

    # Deprecated key -> suggested replacements shown in the warning (may be empty).
    _DEPRECATION_REPLACEMENTS: dict[str, list[str]] = {
        "execution_date": ["data_interval_start", "logical_date"],
        "next_ds": ["{{ data_interval_end | ds }}"],
        "next_ds_nodash": ["{{ data_interval_end | ds_nodash }}"],
        "next_execution_date": ["data_interval_end"],
        "prev_ds": [],
        "prev_ds_nodash": [],
        "prev_execution_date": [],
        "prev_execution_date_success": ["prev_data_interval_start_success"],
        "tomorrow_ds": [],
        "tomorrow_ds_nodash": [],
        "yesterday_ds": [],
        "yesterday_ds_nodash": [],
    }

    def __init__(self, context: MutableMapping[str, Any] | None = None, **kwargs: Any) -> None:
        # A non-empty ``context`` is wrapped, not copied: ``kwargs`` and later
        # writes go straight into the caller's mapping. A ``None`` or empty
        # ``context`` is replaced by a fresh dict.
        self._context: MutableMapping[str, Any] = context or {}
        if kwargs:
            self._context.update(kwargs)
        # Per-instance copy so __setitem__/__delitem__ can drop entries without
        # touching the class-level table or other instances.
        self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()

    def __repr__(self) -> str:
        return repr(self._context)

    def __reduce_ex__(self, protocol: SupportsIndex) -> tuple[Any, ...]:
        """Pickle the context as a dict.

        We are intentionally going through ``__getitem__`` in this function,
        instead of using ``items()``, to trigger deprecation warnings.
        """
        items = [(key, self[key]) for key in self._context]
        return dict, (items,)

    def __copy__(self) -> Context:
        # Shallow-copy the underlying mapping and keep this instance's
        # (possibly reduced) deprecation table.
        new = type(self)(copy.copy(self._context))
        new._deprecation_replacements = self._deprecation_replacements.copy()
        return new

    def __getitem__(self, key: str) -> Any:
        """Return ``self._context[key]``, warning first if ``key`` is deprecated.

        :raises KeyError: if ``key`` is not present. No deprecation warning is
            emitted in that case, so probing with ``get()`` stays silent.
        """
        try:
            value = self._context[key]
        except KeyError:
            raise KeyError(key) from None
        replacements = self._deprecation_replacements.get(key)
        if replacements is not None:
            # stacklevel=2 attributes the warning to the code reading the key.
            warnings.warn(_create_deprecation_warning(key, replacements), stacklevel=2)
        return value

    def __setitem__(self, key: str, value: Any) -> None:
        # An explicitly assigned value is no longer a deprecated one.
        self._deprecation_replacements.pop(key, None)
        self._context[key] = value

    def __delitem__(self, key: str) -> None:
        self._deprecation_replacements.pop(key, None)
        del self._context[key]

    def __contains__(self, key: object) -> bool:
        return key in self._context

    def __iter__(self) -> Iterator[str]:
        return iter(self._context)

    def __len__(self) -> int:
        return len(self._context)

    def __eq__(self, other: Any) -> bool:
        # Only another Context compares by content. Anything else defers via
        # NotImplemented.
        if not isinstance(other, Context):
            return NotImplemented
        return self._context == other._context

    def __ne__(self, other: Any) -> bool:
        if not isinstance(other, Context):
            return NotImplemented
        return self._context != other._context

    def keys(self) -> KeysView[str]:
        return self._context.keys()

    def items(self) -> ItemsView[str, Any]:
        # Views over the raw mapping: these deliberately bypass __getitem__
        # and therefore never emit deprecation warnings.
        return ItemsView(self._context)

    def values(self) -> ValuesView[Any]:
        return ValuesView(self._context)
def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
    """Merge parameters into an existing context, in place.

    Accepts exactly what ``dict.update()`` accepts: an optional mapping or
    iterable of key/value pairs, plus keyword arguments. They are applied
    to ``context`` through its own ``update`` method.

    This is a free function, not a method, because ``Context`` is "faked" as
    a ``TypedDict`` in ``context.pyi``, and a ``TypedDict`` cannot declare
    custom methods.

    :meta private:
    """
    merge = context.update
    merge(*args, **kwargs)
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """Point the template context at the unmapped task, in place.

    ``get_template_context()`` runs before the task is unmapped, so
    ``context`` initially describes the mapped task. This does three things:

    - rebinds ``context["task"]`` to ``task``;
    - rebinds the task instance's ``task`` attribute to ``task``;
    - recomputes ``context["params"]`` for ``task`` with ``process_params``,
      passing ``suppress_exception=False`` so errors propagate.

    :param context: The template context to update.
    :param task: The unmapped operator.

    :meta private:
    """
    from airflow.models.param import process_params

    context["task"] = task
    context["ti"].task = task
    dag, dag_run = context["dag"], context["dag_run"]
    context["params"] = process_params(dag, task, dag_run, suppress_exception=False)
def context_copy_partial(source: Context, keys: Container[str]) -> Context:
    """Create a new ``Context`` holding only the entries of ``source`` under ``keys``.

    Values are taken from the underlying mapping, so no deprecation warnings
    fire here. The new context inherits a copy of the source's deprecation
    table, so deprecated keys still warn when later read from the copy.

    This is a free function, not a method, because ``Context`` is "faked" as
    a ``TypedDict`` in ``context.pyi``, and a ``TypedDict`` cannot declare
    custom methods.

    :param source: The context to copy from.
    :param keys: Container tested with ``in`` for each key of ``source``.
    :return: A new ``Context`` with the selected entries.

    :meta private:
    """
    picked: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in keys:
            picked[name] = value
    result = Context(picked)
    result._deprecation_replacements = dict(source._deprecation_replacements)
    return result
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
    """Create a mapping that wraps deprecated entries in a lazy object proxy.

    This delays the deprecation warning further, until the entry is actually
    used rather than when it is read from the context. The result is useful
    for passing into a callable with ``**kwargs``, which would otherwise
    unpack the mapping too eagerly.

    This is implemented as a free function because the ``Context`` type is
    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
    functions.

    :param source: The context to wrap. Anything that is not a ``Context`` is
        returned unchanged.
    :return: A new dict with the same keys as ``source``. Values of keys that
        are deprecated at call time are wrapped in ``lazy_object_proxy.Proxy``
        objects, whose factory emits the deprecation warning when the proxy
        is resolved.

    :meta private:
    """
    if not isinstance(source, Context):
        # Sometimes we are passed a plain dict (usually in tests, or in users'
        # custom operators) -- be lenient about what we accept so we don't
        # break anything for users.
        return source

    def _deprecated_proxy_factory(k: str, v: Any, replacements: list[str]) -> Any:
        warnings.warn(_create_deprecation_warning(k, replacements))
        return v

    def _create_value(k: str, v: Any) -> Any:
        if k not in source._deprecation_replacements:
            return v
        # Capture the replacements now rather than when the proxy resolves.
        # By then ``source`` may have had ``k`` re-assigned or deleted, which
        # pops it from ``_deprecation_replacements``, and a deferred lookup
        # would fail with KeyError instead of warning.
        replacements = source._deprecation_replacements[k]
        factory = functools.partial(_deprecated_proxy_factory, k, v, replacements)
        return lazy_object_proxy.Proxy(factory)

    return {k: _create_value(k, v) for k, v in source._context.items()}