1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import datetime
21import logging
22from collections.abc import Collection, Iterable, Sequence
23from dataclasses import dataclass
24from typing import TYPE_CHECKING, Any
25
26import jinja2
27import jinja2.nativetypes
28import jinja2.sandbox
29
30from airflow.sdk import ObjectStoragePath
31from airflow.sdk.definitions._internal.mixins import ResolveMixin
32from airflow.sdk.definitions.context import render_template_as_native, render_template_to_string
33
34if TYPE_CHECKING:
35 from airflow.sdk.definitions.context import Context
36 from airflow.sdk.definitions.dag import DAG
37 from airflow.sdk.types import Operator
38
39
@dataclass(frozen=True)
class LiteralValue(ResolveMixin):
    """
    Wrap a value so it is handed back untouched instead of being jinja-templated.

    Resolving the wrapper yields the wrapped object itself; no templating is applied to
    its contents.

    :param value: The object to return verbatim when the wrapper is resolved
    """

    value: Any

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        # A literal carries no (operator, key) references, so report none.
        return tuple()

    def resolve(self, context: Context) -> Any:
        # The context is intentionally unused: the wrapped object is the result.
        wrapped = self.value
        return wrapped
55
56
# Module-level logger; Templater.resolve_template_files reports template-file load failures here.
log: logging.Logger = logging.getLogger(__name__)
58
59
class Templater:
    """
    Render the template fields of an object.

    Subclasses declare ``template_fields`` (names of the attributes to render) and
    ``template_ext`` (file suffixes marking a string value as a template file name rather
    than an inline template).

    :meta private:
    """

    # For derived classes to define which fields will get jinjaified.
    template_fields: Collection[str]
    # Defines which files extensions to look for in the templated fields.
    template_ext: Sequence[str]

    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """
        Fetch a Jinja template environment from the Dag or instantiate empty environment if no Dag.

        :param dag: Dag whose template environment should be used, if any.
        :return: ``dag.get_template_env(force_sandboxed=False)`` when a Dag is given,
            otherwise a new ``SandboxedEnvironment`` with template caching disabled
            (``cache_size=0``).
        """
        if dag:
            return dag.get_template_env(force_sandboxed=False)
        return SandboxedEnvironment(cache_size=0)

    def prepare_template(self) -> None:
        """
        Execute after the templated fields get replaced by their content.

        If you need your object to alter the content of the file before the
        template is rendered, it should override this method to do so.
        """

    def resolve_template_files(self) -> None:
        """
        Get the content of files for template_field / template_ext.

        Each field in ``template_fields`` whose value is a string ending with one of
        ``template_ext`` is replaced by that template file's source, as returned by the
        template environment's loader. For list-valued fields, matching string items are
        replaced in place. Load failures are logged and leave the value unchanged.
        ``prepare_template`` is called afterwards in every case.
        """
        if self.template_ext:
            suffixes = tuple(self.template_ext)
            # Built on first use and shared by all fields rather than rebuilt per field.
            env: jinja2.Environment | None = None
            for field in self.template_fields:
                content = getattr(self, field, None)
                if isinstance(content, str) and content.endswith(suffixes):
                    if env is None:
                        env = self.get_template_env()
                    try:
                        setattr(self, field, env.loader.get_source(env, content)[0])  # type: ignore
                    except Exception:
                        log.exception("Failed to resolve template field %r", field)
                elif isinstance(content, list):
                    for i, item in enumerate(content):
                        if isinstance(item, str) and item.endswith(suffixes):
                            if env is None:
                                env = self.get_template_env()
                            try:
                                content[i] = env.loader.get_source(env, item)[0]  # type: ignore
                            except Exception:
                                log.exception("Failed to get source %s", item)
        self.prepare_template()

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """
        Render each attribute of *parent* named in *template_fields* and store the result back.

        :param parent: Object whose attributes are rendered (e.g. a nested object found by
            ``_render_nested_template_fields``).
        :param template_fields: Names of the attributes of *parent* to render.
        :param context: Template context.
        :param jinja_env: Environment used to compile string templates.
        :param seen_oids: ids of objects already visited, used to stop infinite recursion.
        """
        for attr_name in template_fields:
            value = getattr(parent, attr_name)
            rendered_content = self.render_template(value, context, jinja_env, seen_oids)
            # Always store the result: a template may legitimately render to a falsy value
            # ("", 0, False, None). Skipping falsy results left the unrendered template text
            # in the attribute.
            setattr(parent, attr_name, rendered_content)

    def _render(self, template, context, dag=None) -> Any:
        """
        Render a compiled *template* with *context*.

        Produces a native Python object when *dag* has ``render_template_as_native_obj`` set,
        and a string otherwise. ``render_template`` calls this without a Dag, so this base
        implementation renders to strings there.
        """
        if dag and dag.render_template_as_native_obj:
            return render_template_as_native(template, context)
        return render_template_to_string(template, context)

    def render_template(
        self,
        content: Any,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
        seen_oids: set[int] | None = None,
    ) -> Any:
        """
        Render a templated string.

        If *content* is a collection holding multiple templated strings, strings
        in the collection will be templated recursively.

        Dispatch rules, in order:

        * strings ending with one of ``template_ext`` are loaded as template files, and
          other strings are compiled as inline templates;
        * ``ObjectStoragePath`` values have their path rendered;
        * ``ResolveMixin`` instances (e.g. ``LiteralValue``) are resolved against *context*;
        * tuples (including named tuples), lists, dict values and sets are rendered element-wise;
        * any other object has its own ``template_fields`` (if any) rendered in place.

        :param content: Content to template. Only strings can be templated (may
            be inside a collection).
        :param context: Dict with values to apply on templated content
        :param jinja_env: Jinja environment. Can be provided to avoid
            re-creating Jinja environments during recursion.
        :param seen_oids: template fields already rendered (to avoid
            *RecursionError* on circular dependencies)
        :return: Templated content
        """
        # "content" is a bad name, but we're stuck to it being public API.
        value = content
        del content

        oids = seen_oids if seen_oids is not None else set()

        if id(value) in oids:
            return value

        if not jinja_env:
            jinja_env = self.get_template_env()

        if isinstance(value, str):
            if value.endswith(tuple(self.template_ext)):  # A filepath.
                template = jinja_env.get_template(value)
            else:
                template = jinja_env.from_string(value)
            return self._render(template, context)
        if isinstance(value, ObjectStoragePath):
            return self._render_object_storage_path(value, context, jinja_env)

        # Resolve only objects implementing the ResolveMixin protocol. Duck-typing on a
        # ``resolve`` attribute also matched e.g. ``pathlib.Path.resolve``, which would then be
        # invoked with the context dict as its ``strict`` argument.
        if isinstance(value, ResolveMixin):
            return value.resolve(context)

        # Fast path for common built-in collections.
        if value.__class__ is tuple:
            return tuple(self.render_template(element, context, jinja_env, oids) for element in value)
        if isinstance(value, tuple):  # Special case for named tuples.
            return value.__class__(*(self.render_template(el, context, jinja_env, oids) for el in value))
        if isinstance(value, list):
            return [self.render_template(element, context, jinja_env, oids) for element in value]
        if isinstance(value, dict):
            # Only values are rendered; keys are kept as-is.
            return {k: self.render_template(v, context, jinja_env, oids) for k, v in value.items()}
        if isinstance(value, set):
            return {self.render_template(element, context, jinja_env, oids) for element in value}

        # More complex collections: objects that declare their own template_fields.
        self._render_nested_template_fields(value, context, jinja_env, oids)
        return value

    def _render_object_storage_path(
        self, value: ObjectStoragePath, context: Context, jinja_env: jinja2.Environment
    ) -> ObjectStoragePath:
        """
        Render the ``path`` entry of *value*'s serialized form and rebuild the path from it.

        Only the ``path`` entry is templated; the rest of the serialized data and the
        path's ``__version__`` are passed through unchanged.
        """
        serialized_path = value.serialize()
        path_version = value.__version__
        serialized_path["path"] = self._render(jinja_env.from_string(serialized_path["path"]), context)
        return value.deserialize(data=serialized_path, version=path_version)

    def _render_nested_template_fields(
        self,
        value: Any,
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """
        Render, in place, the ``template_fields`` declared by an arbitrary object *value*.

        Each object is visited at most once (tracked by ``id`` in *seen_oids*). Objects
        without a ``template_fields`` attribute are left untouched.
        """
        if id(value) in seen_oids:
            return
        seen_oids.add(id(value))
        try:
            nested_template_fields = value.template_fields
        except AttributeError:
            # content has no inner template fields
            return
        self._do_render_template_fields(value, nested_template_fields, context, jinja_env, seen_oids)
224
225
class _AirflowEnvironmentMixin:
    """Common setup for Airflow's Jinja environments: Airflow filters plus relaxed attribute checks."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Make every Airflow filter (``ds``, ``ts``, ...) available under its FILTERS name.
        for filter_name, filter_func in FILTERS.items():
            self.filters[filter_name] = filter_func

    def is_safe_attribute(self, obj, attr, value):
        """
        Allow access to ``_`` prefix vars (but not ``__``).

        The stock SandboxedEnvironment refuses every attribute that starts with ``_``.
        Here only what jinja's ``is_internal_attribute`` reports is refused, which still
        covers ``__`` prefixed names and interpreter internals.
        """
        internal = jinja2.sandbox.is_internal_attribute(obj, attr)
        return not internal
240
241
class NativeEnvironment(_AirflowEnvironmentMixin, jinja2.nativetypes.NativeEnvironment):
    """
    Jinja native-types environment for Airflow task templates.

    Combines jinja2's ``NativeEnvironment`` with the Airflow template filters and the
    attribute-access rule provided by ``_AirflowEnvironmentMixin``.
    """
244
245
class SandboxedEnvironment(_AirflowEnvironmentMixin, jinja2.sandbox.SandboxedEnvironment):
    """
    Sandboxed Jinja environment for Airflow task templates.

    Combines jinja2's ``SandboxedEnvironment`` with the Airflow template filters and the
    relaxed ``_``-prefix attribute rule provided by ``_AirflowEnvironmentMixin``.
    """
248
249
def ds_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Date filter: format *value* as ``YYYY-MM-DD``.

    :param value: Object to format; ``None`` is passed through.
    :return: The formatted string, or ``None`` when *value* is ``None``.
    """
    return None if value is None else value.strftime("%Y-%m-%d")
255
256
def ds_nodash_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Date filter without dashes: format *value* as ``YYYYMMDD``.

    :param value: Object to format; ``None`` is passed through.
    :return: The formatted string, or ``None`` when *value* is ``None``.
    """
    if value is not None:
        return value.strftime("%Y%m%d")
    return None
262
263
def ts_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Timestamp filter: return *value* in ISO-8601 form via ``isoformat()``.

    :param value: Object to format; ``None`` is passed through.
    :return: The ISO-8601 string, or ``None`` when *value* is ``None``.
    """
    if value is not None:
        return value.isoformat()
    return None
269
270
def ts_nodash_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Timestamp filter without dashes: format *value* as ``YYYYMMDDTHHMMSS``.

    :param value: Object to format; ``None`` is passed through.
    :return: The formatted string, or ``None`` when *value* is ``None``.
    """
    return value.strftime("%Y%m%dT%H%M%S") if value is not None else None
276
277
def ts_nodash_with_tz_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Timestamp filter with timezone: ISO-8601 without ``-`` date and ``:`` time separators.

    The sign of the UTC offset is preserved, e.g. ``2024-01-02T03:04:05-05:00`` becomes
    ``20240102T030405-0500`` and ``+00:00`` becomes ``+0000``.

    :param value: Date, datetime or time to format; ``None`` is passed through.
    :return: The compact ISO-8601 string, or ``None`` when *value* is ``None``.
    """
    if value is None:
        return None
    iso = value.isoformat()
    if isinstance(value, datetime.time):
        # A time has no date separators, so any "-" is the sign of a negative UTC offset.
        return iso.replace(":", "")
    if isinstance(value, datetime.datetime):
        # Only the date half carries "-" separators. Stripping "-" from the whole string
        # turned "-05:00" into "0500", indistinguishable from a positive offset.
        date_part, sep, time_part = iso.partition("T")
        if sep:
            return date_part.replace("-", "") + sep + time_part.replace(":", "")
    # Plain dates (and any datetime whose isoformat lacks a "T" separator).
    return iso.replace("-", "").replace(":", "")
283
284
# Custom Jinja filters installed on Airflow template environments by
# ``_AirflowEnvironmentMixin.__init__``, keyed by the name used inside templates.
FILTERS = dict(
    ds=ds_filter,
    ds_nodash=ds_nodash_filter,
    ts=ts_filter,
    ts_nodash=ts_nodash_filter,
    ts_nodash_with_tz=ts_nodash_with_tz_filter,
)