Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/utils/context.py: 36%
125 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Jinja2 template rendering context helper."""
19from __future__ import annotations
21import contextlib
22import copy
23import functools
24import warnings
25from typing import (
26 TYPE_CHECKING,
27 Any,
28 Container,
29 ItemsView,
30 Iterator,
31 KeysView,
32 Mapping,
33 MutableMapping,
34 ValuesView,
35)
37import lazy_object_proxy
39from airflow.exceptions import RemovedInAirflow3Warning
40from airflow.utils.types import NOTSET
42if TYPE_CHECKING:
43 from airflow.models.baseoperator import BaseOperator
# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi.
# Names of the entries recognised in the template context, listed
# alphabetically. This is a set, so the order carries no meaning.
KNOWN_CONTEXT_KEYS = {
    "conf",
    "conn",
    "dag",
    "dag_run",
    "data_interval_end",
    "data_interval_start",
    "ds",
    "ds_nodash",
    "exception",
    "execution_date",
    "expanded_ti_count",
    "inlets",
    "logical_date",
    "macros",
    "next_ds",
    "next_ds_nodash",
    "next_execution_date",
    "outlets",
    "params",
    "prev_data_interval_end_success",
    "prev_data_interval_start_success",
    "prev_ds",
    "prev_ds_nodash",
    "prev_execution_date",
    "prev_execution_date_success",
    "prev_start_date_success",
    "run_id",
    "task",
    "task_instance",
    "task_instance_key_str",
    "templates_dict",
    "test_mode",
    "ti",
    "tomorrow_ds",
    "tomorrow_ds_nodash",
    "triggering_dataset_events",
    "try_number",
    "ts",
    "ts_nodash",
    "ts_nodash_with_tz",
    "var",
    "yesterday_ds",
    "yesterday_ds_nodash",
}
class VariableAccessor:
    """Expose Airflow Variables to templates via attribute access or ``get()``.

    The most recently fetched value (through attribute access) is remembered
    in ``self.var`` and is what ``repr()`` of the accessor displays.
    """

    def __init__(self, *, deserialize_json: bool) -> None:
        # Forwarded as-is to every ``Variable.get`` call made by this accessor.
        self._deserialize_json = deserialize_json
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        """Fetch the Variable named ``key`` and remember it as the last value."""
        # Imported lazily, inside the method, rather than at module import time.
        from airflow.models.variable import Variable

        fetched = Variable.get(key, deserialize_json=self._deserialize_json)
        self.var = fetched
        return fetched

    def __repr__(self) -> str:
        """Return the string form of the last value fetched by attribute access."""
        return str(self.var)

    def get(self, key, default: Any = NOTSET) -> Any:
        """Fetch the Variable named ``key``.

        ``default`` is only forwarded to ``Variable.get`` when the caller
        supplied one; otherwise ``Variable.get`` is called without it.
        """
        from airflow.models.variable import Variable

        positional = (key,) if default is NOTSET else (key, default)
        return Variable.get(*positional, deserialize_json=self._deserialize_json)
class ConnectionAccessor:
    """Expose Airflow Connections to templates via attribute access or ``get()``.

    The most recently fetched connection (through attribute access) is
    remembered in ``self.var`` and is what ``repr()`` of the accessor displays.
    """

    def __init__(self) -> None:
        self.var: Any = None

    def __getattr__(self, key: str) -> Any:
        """Look up the connection named ``key`` and remember it as the last value."""
        # Imported lazily, inside the method, rather than at module import time.
        from airflow.models.connection import Connection

        found = Connection.get_connection_from_secrets(key)
        self.var = found
        return found

    def __repr__(self) -> str:
        """Return the string form of the last connection fetched by attribute access."""
        return str(self.var)

    def get(self, key: str, default_conn: Any = None) -> Any:
        """Look up the connection named ``key``.

        Returns ``default_conn`` when the lookup raises
        ``AirflowNotFoundException``.
        """
        from airflow.exceptions import AirflowNotFoundException
        from airflow.models.connection import Connection

        try:
            found = Connection.get_connection_from_secrets(key)
        except AirflowNotFoundException:
            return default_conn
        return found
class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
    """Warning category for deprecated template context variables used in a task."""
def _create_deprecation_warning(key: str, replacements: list[str]) -> RemovedInAirflow3Warning:
    """Build the deprecation warning for accessing context entry ``key``.

    :param key: Name of the deprecated context entry.
    :param replacements: Names to suggest instead. When empty, the message
        carries no suggestion; otherwise all names are listed, with the last
        one joined by "or".
    :return: An ``AirflowContextDeprecationWarning`` instance (not yet emitted).
    """
    text = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
    if replacements:
        *leading, final = replacements
        if leading:
            alternatives = f"{', '.join(map(repr, leading))} or {final!r}"
        else:
            alternatives = repr(final)
        text = f"{text} Please use {alternatives} instead."
    return AirflowContextDeprecationWarning(text)
class Context(MutableMapping[str, Any]):
    """Jinja2 template context for task rendering.

    This is a mapping (dict-like) class that can lazily emit warnings when
    (and only when) deprecated context keys are accessed.

    Warnings are emitted by ``__getitem__`` (and therefore by anything built
    on it, such as ``get()`` and pickling). ``items()`` and ``values()`` read
    the underlying mapping directly and never warn.
    """

    # Deprecated key -> names to suggest instead. An empty list means the key
    # is deprecated without a suggested replacement.
    _DEPRECATION_REPLACEMENTS: dict[str, list[str]] = {
        "execution_date": ["data_interval_start", "logical_date"],
        "next_ds": ["{{ data_interval_end | ds }}"],
        "next_ds_nodash": ["{{ data_interval_end | ds_nodash }}"],
        "next_execution_date": ["data_interval_end"],
        "prev_ds": [],
        "prev_ds_nodash": [],
        "prev_execution_date": [],
        "prev_execution_date_success": ["prev_data_interval_start_success"],
        "tomorrow_ds": [],
        "tomorrow_ds_nodash": [],
        "yesterday_ds": [],
        "yesterday_ds_nodash": [],
    }

    def __init__(self, context: MutableMapping[str, Any] | None = None, **kwargs: Any) -> None:
        """Wrap ``context`` (by reference) and merge ``kwargs`` into it.

        :param context: Mapping to wrap; a new dict is created when ``None``.
        :param kwargs: Extra entries written into the wrapped mapping.
        """
        # Compare against None explicitly: an empty mapping passed by the
        # caller must be wrapped like any other mapping, not replaced.
        self._context: MutableMapping[str, Any] = {} if context is None else context
        if kwargs:
            self._context.update(kwargs)
        # Per-instance copy so that overriding or deleting a key can silence
        # its warning without touching the class-level table.
        self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()

    def __repr__(self) -> str:
        return repr(self._context)

    def __reduce_ex__(self, protocol: int) -> tuple[Any, ...]:
        """Pickle the context as a dict.

        We are intentionally going through ``__getitem__`` in this function,
        instead of using ``items()``, to trigger deprecation warnings.
        """
        items = [(key, self[key]) for key in self._context]
        return dict, (items,)

    def __copy__(self) -> Context:
        """Return a shallow copy that keeps this instance's deprecation table."""
        new = type(self)(copy.copy(self._context))
        new._deprecation_replacements = self._deprecation_replacements.copy()
        return new

    def __getitem__(self, key: str) -> Any:
        """Return the value under ``key``, warning if the key is deprecated.

        :raises KeyError: If ``key`` is not present. No deprecation warning is
            emitted in that case, since nothing deprecated was actually read.
        """
        # Look the value up first so that a missing key (e.g. via ``get()``
        # with a default) does not produce a spurious deprecation warning.
        try:
            value = self._context[key]
        except KeyError:
            raise KeyError(key) from None
        # ``None`` means "not deprecated"; an empty list is still deprecated.
        replacements = self._deprecation_replacements.get(key)
        if replacements is not None:
            warnings.warn(_create_deprecation_warning(key, replacements))
        return value

    def __setitem__(self, key: str, value: Any) -> None:
        # An explicitly assigned key is no longer treated as deprecated.
        self._deprecation_replacements.pop(key, None)
        self._context[key] = value

    def __delitem__(self, key: str) -> None:
        self._deprecation_replacements.pop(key, None)
        del self._context[key]

    def __contains__(self, key: object) -> bool:
        return key in self._context

    def __iter__(self) -> Iterator[str]:
        return iter(self._context)

    def __len__(self) -> int:
        return len(self._context)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Context):
            return NotImplemented
        return self._context == other._context

    def __ne__(self, other: Any) -> bool:
        if not isinstance(other, Context):
            return NotImplemented
        return self._context != other._context

    def keys(self) -> KeysView[str]:
        return self._context.keys()

    def items(self) -> ItemsView[str, Any]:
        # View over the raw mapping: bypasses ``__getitem__`` and its warnings.
        return ItemsView(self._context)

    def values(self) -> ValuesView[Any]:
        # View over the raw mapping: bypasses ``__getitem__`` and its warnings.
        return ValuesView(self._context)
def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
    """Merge parameters into an existing context, in place.

    Accepts the same arguments as ``dict.update()`` and forwards them to
    ``context.update()``. Nothing is returned.

    This lives in a free function because the ``Context`` type is "faked" as
    a ``TypedDict`` in ``context.pyi``, which cannot have custom functions.

    :meta private:
    """
    context.update(*args, **kwargs)
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """Point the context at the unmapped task, in place.

    ``get_template_context()`` is called before unmapping, so the context it
    produced still describes the mapped task. This replaces the ``task``
    entry, the ``task`` attribute of the ``ti`` entry, and the ``params``
    entry so that the template context reflects the unmapped task instead.

    :meta private:
    """
    from airflow.models.param import process_params

    # Same order as a chained assignment: the context entry first, then the
    # attribute on the task instance.
    context["task"] = task
    context["ti"].task = task
    context["params"] = process_params(
        context["dag"],
        task,
        context["dag_run"],
        suppress_exception=False,
    )
def context_copy_partial(source: Context, keys: Container[str]) -> Context:
    """Build a new context holding only the entries of ``source`` named in ``keys``.

    Entries are read from the underlying mapping, so no deprecation warning
    is triggered while copying. The new context inherits a copy of
    ``source``'s deprecation table.

    This lives in a free function because the ``Context`` type is "faked" as
    a ``TypedDict`` in ``context.pyi``, which cannot have custom functions.

    :meta private:
    """
    selected: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in keys:
            selected[name] = value
    partial = Context(selected)
    partial._deprecation_replacements = source._deprecation_replacements.copy()
    return partial
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
    """Return a plain mapping whose deprecated entries are lazy object proxies.

    Each deprecated entry is wrapped so that its deprecation warning is
    emitted when the proxy is first resolved, rather than when the entry is
    read out of the context. This makes the result suitable for passing into
    a callable with ``**kwargs``, which would otherwise unpack the mapping
    (and warn) too eagerly. Non-deprecated entries are passed through as-is.

    This lives in a free function because the ``Context`` type is "faked" as
    a ``TypedDict`` in ``context.pyi``, which cannot have custom functions.

    :meta private:
    """
    if not isinstance(source, Context):
        # Sometimes we are passed a plain dict (usually in tests, or in user's
        # custom operators) -- be lenient about what we accept so we don't
        # break anything for users.
        return source

    def _warn_then_unwrap(name: str, value: Any) -> Any:
        # Runs when the proxy is resolved: emit the warning, hand back the value.
        suggestions = source._deprecation_replacements[name]
        warnings.warn(_create_deprecation_warning(name, suggestions))
        return value

    lazy: dict[str, Any] = {}
    for name, value in source._context.items():
        if name in source._deprecation_replacements:
            # functools.partial binds this iteration's name/value pair.
            value = lazy_object_proxy.Proxy(functools.partial(_warn_then_unwrap, name, value))
        lazy[name] = value
    return lazy