1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import datetime
21import logging
22from collections.abc import Collection, Iterable, Sequence
23from dataclasses import dataclass
24from typing import TYPE_CHECKING, Any
25
26import jinja2
27import jinja2.nativetypes
28import jinja2.sandbox
29
30from airflow.sdk import ObjectStoragePath
31from airflow.sdk.definitions._internal.mixins import ResolveMixin
32from airflow.sdk.definitions.context import render_template_as_native, render_template_to_string
33
34if TYPE_CHECKING:
35 from airflow.sdk.definitions.context import Context
36 from airflow.sdk.definitions.dag import DAG
37 from airflow.sdk.types import Operator
38
39
@dataclass(frozen=True)
class LiteralValue(ResolveMixin):
    """
    A wrapper for a value that should be rendered as-is, without applying jinja templating to its contents.

    :param value: The value to be rendered without templating
    """

    # The wrapped object; ``resolve`` hands it back untouched.
    value: Any

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty tuple: a literal value references no operators."""
        return ()

    def resolve(self, context: Context) -> Any:
        """Return the wrapped value unchanged; *context* is not consulted."""
        return self.value
55
56
# Module-level logger; used below to report template files whose source cannot be loaded.
log = logging.getLogger(__name__)
58
59
class Templater:
    """
    Base class that renders the Jinja-templated fields of an object.

    :meta private:
    """

    # Attribute names that derived classes want rendered with Jinja.
    template_fields: Collection[str]
    # File extensions marking a templated string as the path of a template file.
    template_ext: Sequence[str]

    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """
        Return the Jinja environment to render with.

        :param dag: Dag whose environment should be used, if any
        :return: The Dag's environment (requested with ``force_sandboxed=False``),
            or a new uncached ``SandboxedEnvironment`` when no Dag is given
        """
        if not dag:
            return SandboxedEnvironment(cache_size=0)
        return dag.get_template_env(force_sandboxed=False)

    def prepare_template(self) -> None:
        """
        Hook invoked at the end of :meth:`resolve_template_files`.

        The default implementation does nothing. Override it when the object
        has to adjust the loaded file content before the template is rendered.
        """

    def resolve_template_files(self) -> None:
        """
        Replace template file paths held in templated fields by the file content.

        A string field (or a string item of a list field) ending with one of
        ``template_ext`` is looked up through the environment loader. Lookup
        failures are logged and the original value is left in place.
        """
        if self.template_ext:
            suffixes = tuple(self.template_ext)
            for field_name in self.template_fields:
                current = getattr(self, field_name, None)
                if isinstance(current, str):
                    if not current.endswith(suffixes):
                        continue
                    env = self.get_template_env()
                    try:
                        setattr(self, field_name, env.loader.get_source(env, current)[0])  # type: ignore
                    except Exception:
                        log.exception("Failed to resolve template field %r", field_name)
                elif isinstance(current, list):
                    env = self.get_template_env()
                    for position, entry in enumerate(current):
                        if not (isinstance(entry, str) and entry.endswith(suffixes)):
                            continue
                        try:
                            current[position] = env.loader.get_source(env, entry)[0]  # type: ignore
                        except Exception:
                            log.exception("Failed to get source %s", entry)
        self.prepare_template()

    def _should_render_native(self, dag: DAG | None = None) -> bool:
        """Tell whether templates render to native objects; an explicit setting on self beats the Dag's."""
        own_setting = getattr(self, "render_template_as_native_obj", None)
        if own_setting is None:
            # Nothing set on the object itself: inherit from the Dag when there is one.
            return dag.render_template_as_native_obj if dag else False
        return own_setting

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """
        Render each named attribute of *parent* and store the result back on it.

        A falsy rendered result is not written back, so the attribute keeps
        its previous value in that case.
        """
        for field_name in template_fields:
            rendered = self.render_template(getattr(parent, field_name), context, jinja_env, seen_oids)
            if not rendered:
                continue
            setattr(parent, field_name, rendered)

    def _render(self, template, context, dag=None) -> Any:
        """Render *template* with *context*, natively or to a string per :meth:`_should_render_native`."""
        renderer = render_template_as_native if self._should_render_native(dag) else render_template_to_string
        return renderer(template, context)

    def render_template(
        self,
        content: Any,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
        seen_oids: set[int] | None = None,
    ) -> Any:
        """
        Render a templated value.

        Strings are rendered directly (or loaded as a template file when they
        end with one of ``template_ext``). Tuples, named tuples, lists, dicts
        and sets are rebuilt with every element rendered recursively. Any
        other object has its own ``template_fields`` rendered in place.

        :param content: Content to template. Only strings can be templated (may
            be inside a collection).
        :param context: Dict with values to apply on templated content
        :param jinja_env: Jinja environment. Can be provided to avoid
            re-creating Jinja environments during recursion.
        :param seen_oids: template fields already rendered (to avoid
            *RecursionError* on circular dependencies)
        :return: Templated content
        """
        # The parameter name is public API; use a clearer local name from here on.
        value = content

        oids = set() if seen_oids is None else seen_oids
        if id(value) in oids:
            return value

        jinja_env = jinja_env or self.get_template_env()

        if isinstance(value, str):
            is_template_file = value.endswith(tuple(self.template_ext))
            template = jinja_env.get_template(value) if is_template_file else jinja_env.from_string(value)
            return self._render(template, context)
        if isinstance(value, ObjectStoragePath):
            return self._render_object_storage_path(value, context, jinja_env)

        resolve = getattr(value, "resolve", None)
        if resolve:
            return resolve(context)

        def render_child(child: Any) -> Any:
            # Recurse with the same environment and the shared set of seen ids.
            return self.render_template(child, context, jinja_env, oids)

        # Fast path for common built-in collections.
        if value.__class__ is tuple:
            return tuple(map(render_child, value))
        if isinstance(value, tuple):  # Named tuples take their fields positionally.
            return value.__class__(*map(render_child, value))
        if isinstance(value, list):
            return list(map(render_child, value))
        if isinstance(value, dict):
            return {key: render_child(item) for key, item in value.items()}
        if isinstance(value, set):
            return set(map(render_child, value))

        # Anything else: render its own template fields in place, return the same object.
        self._render_nested_template_fields(value, context, jinja_env, oids)
        return value

    def _render_object_storage_path(
        self, value: ObjectStoragePath, context: Context, jinja_env: jinja2.Environment
    ) -> ObjectStoragePath:
        """Render the ``path`` entry of the serialized *value* and rebuild the path object from it."""
        payload = value.serialize()
        version = value.__version__
        payload["path"] = self._render(jinja_env.from_string(payload["path"]), context)
        return value.deserialize(data=payload, version=version)

    def _render_nested_template_fields(
        self,
        value: Any,
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """Render *value*'s own ``template_fields`` in place, at most once per object identity."""
        object_id = id(value)
        if object_id in seen_oids:
            return
        seen_oids.add(object_id)
        try:
            inner_fields = value.template_fields
        except AttributeError:
            # The object declares no template fields of its own.
            return
        self._do_render_template_fields(value, inner_fields, context, jinja_env, seen_oids)
232
233
class _AirflowEnvironmentMixin:
    """Mixin for Jinja environments that registers ``FILTERS`` and permits ``_``-prefixed attributes."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Register the module's filters on this environment under their template names.
        for filter_name, filter_func in FILTERS.items():
            self.filters[filter_name] = filter_func

    def is_safe_attribute(self, obj, attr, value):
        """
        Allow access to ``_`` prefix vars (but not ``__``).

        The stock SandboxedEnvironment rejects every attribute starting with ``_``. Here only
        attributes that Jinja classifies as internal (``__`` prefixed and similar) are refused,
        so "private" single-underscore attributes stay reachable from templates.
        """
        attribute_is_internal = jinja2.sandbox.is_internal_attribute(obj, attr)
        return not attribute_is_internal
248
249
class NativeEnvironment(_AirflowEnvironmentMixin, jinja2.nativetypes.NativeEnvironment):
    """
    NativeEnvironment for Airflow task templates.

    Jinja's native-types environment combined with ``_AirflowEnvironmentMixin``.
    """
252
253
class SandboxedEnvironment(_AirflowEnvironmentMixin, jinja2.sandbox.SandboxedEnvironment):
    """
    SandboxedEnvironment for Airflow task templates.

    Jinja's sandboxed environment combined with ``_AirflowEnvironmentMixin``.
    """
256
257
def ds_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Format *value* as ``YYYY-MM-DD``.

    :param value: Object to format, or ``None``
    :return: The formatted string, or ``None`` when *value* is ``None``
    """
    return None if value is None else value.strftime("%Y-%m-%d")
263
264
def ds_nodash_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Format *value* as ``YYYYMMDD`` (a date without dashes).

    :param value: Object to format, or ``None``
    :return: The formatted string, or ``None`` when *value* is ``None``
    """
    return None if value is None else value.strftime("%Y%m%d")
270
271
def ts_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Format *value* as its ISO 8601 string (``value.isoformat()``).

    :param value: Object to format, or ``None``
    :return: The ISO formatted string, or ``None`` when *value* is ``None``
    """
    return None if value is None else value.isoformat()
277
278
def ts_nodash_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Format *value* as ``YYYYMMDDTHHMMSS`` (a timestamp without dashes or colons).

    :param value: Object to format, or ``None``
    :return: The formatted string, or ``None`` when *value* is ``None``
    """
    return None if value is None else value.strftime("%Y%m%dT%H%M%S")
284
285
def ts_nodash_with_tz_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Format *value* as a compact ISO timestamp that keeps its UTC offset.

    Dashes are removed from the date part and colons from the time part of
    ``value.isoformat()``, e.g. ``2018-01-01T00:00:00+00:00`` becomes
    ``20180101T000000+0000``. The sign of a negative offset is preserved
    (``...-05:00`` becomes ``...-0500``), so east and west offsets stay
    distinguishable.

    :param value: Object to format, or ``None``
    :return: The compact string, or ``None`` when *value* is ``None``
    """
    if value is None:
        return None
    iso = value.isoformat()
    date_part, separator, time_part = iso.partition("T")
    if not separator and ":" in date_part:
        # No "T" but colons present: a bare time, so there is no date part to strip dashes from.
        date_part, time_part = "", date_part
    # Dashes only separate date fields; a "-" in the time part is the UTC-offset sign and must stay.
    return date_part.replace("-", "") + separator + time_part.replace(":", "")
291
292
# Date/time filters keyed by the name they are used under in templates (``{{ value | ds }}``).
# ``_AirflowEnvironmentMixin.__init__`` copies this mapping into each environment's ``filters``.
FILTERS = {
    "ds": ds_filter,
    "ds_nodash": ds_nodash_filter,
    "ts": ts_filter,
    "ts_nodash": ts_nodash_filter,
    "ts_nodash_with_tz": ts_nodash_with_tz_filter,
}
300
301
def create_template_env(
    *,
    native: bool = False,
    searchpath: list[str] | None = None,
    template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
    jinja_environment_kwargs: dict | None = None,
    user_defined_macros: dict | None = None,
    user_defined_filters: dict | None = None,
) -> jinja2.Environment:
    """
    Build a Jinja2 environment for rendering templates.

    :param native: Build a ``NativeEnvironment`` instead of a ``SandboxedEnvironment``
    :param searchpath: Directories for a ``FileSystemLoader``; no loader is set when empty
    :param template_undefined: Class used for undefined template variables
    :param jinja_environment_kwargs: Extra environment options; these override the defaults
        and the loader derived from *searchpath*
    :param user_defined_macros: Names added to the environment globals
    :param user_defined_filters: Filters added to the environment filters
    :return: The configured environment
    """
    # Defaults kept for backward compatibility; caller-supplied kwargs are applied on top.
    options = {
        "undefined": template_undefined,
        "extensions": ["jinja2.ext.do"],
        "cache_size": 0,
    }
    if searchpath:
        options["loader"] = jinja2.FileSystemLoader(searchpath)
    if jinja_environment_kwargs:
        options.update(jinja_environment_kwargs)

    environment_class = NativeEnvironment if native else SandboxedEnvironment
    env = environment_class(**options)

    # Globals and filters may be edited safely as long as no template has been rendered yet.
    # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals
    for extras, target in ((user_defined_macros, env.globals), (user_defined_filters, env.filters)):
        if extras:
            target.update(extras)

    return env