Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/templater.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

136 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import datetime 

21import logging 

22from collections.abc import Collection, Iterable, Sequence 

23from dataclasses import dataclass 

24from typing import TYPE_CHECKING, Any 

25 

26import jinja2 

27import jinja2.nativetypes 

28import jinja2.sandbox 

29 

30from airflow.sdk import ObjectStoragePath 

31from airflow.sdk.definitions._internal.mixins import ResolveMixin 

32from airflow.sdk.definitions.context import render_template_as_native, render_template_to_string 

33 

34if TYPE_CHECKING: 

35 from airflow.sdk.definitions.context import Context 

36 from airflow.sdk.definitions.dag import DAG 

37 from airflow.sdk.types import Operator 

38 

39 

@dataclass(frozen=True)
class LiteralValue(ResolveMixin):
    """
    Wrap a value so it is handed back verbatim instead of being run through jinja templating.

    :param value: The value to be rendered without templating
    """

    value: Any

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty tuple; this wrapper yields no ``(operator, key)`` references."""
        return tuple()

    def resolve(self, context: Context) -> Any:
        """Return the wrapped value unchanged; *context* is not consulted."""
        literal = self.value
        return literal

55 

56 

# Module-level logger; Templater.resolve_template_files uses it to record
# template files whose source could not be loaded.
log = logging.getLogger(__name__)

58 

59 

class Templater:
    """
    Render the template fields of an object.

    Derived classes declare ``template_fields`` (names of attributes whose
    values are rendered) and ``template_ext`` (string suffixes that mark a
    value as the name of a template file rather than inline template text).

    :meta private:
    """

    # For derived classes to define which fields will get jinjaified.
    template_fields: Collection[str]
    # Defines which files extensions to look for in the templated fields.
    template_ext: Sequence[str]

    def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
        """
        Fetch a Jinja template environment from the Dag or instantiate empty environment if no Dag.

        :param dag: Dag to take the environment from; when falsy a fresh
            ``SandboxedEnvironment(cache_size=0)`` is created instead.
        :return: The Jinja environment to render with.
        """
        if dag:
            return dag.get_template_env(force_sandboxed=False)
        return SandboxedEnvironment(cache_size=0)

    def prepare_template(self) -> None:
        """
        Execute after the templated fields get replaced by their content.

        If you need your object to alter the content of the file before the
        template is rendered, it should override this method to do so.
        """

    def resolve_template_files(self) -> None:
        """
        Get the content of files for template_field / template_ext.

        For every name in ``template_fields`` whose value is a string ending in
        one of ``template_ext`` (or a list containing such strings), the string
        is replaced by the source text obtained from the environment's loader.
        Loader failures are logged and the original value is left in place.
        ``prepare_template`` is always called at the end.
        """
        if self.template_ext:
            # str.endswith needs a tuple; build it once instead of per field/item.
            suffixes = tuple(self.template_ext)
            for field in self.template_fields:
                content = getattr(self, field, None)
                if isinstance(content, str) and content.endswith(suffixes):
                    env = self.get_template_env()
                    try:
                        setattr(self, field, env.loader.get_source(env, content)[0])  # type: ignore
                    except Exception:
                        log.exception("Failed to resolve template field %r", field)
                elif isinstance(content, list):
                    env = self.get_template_env()
                    for i, item in enumerate(content):
                        if isinstance(item, str) and item.endswith(suffixes):
                            try:
                                content[i] = env.loader.get_source(env, item)[0]  # type: ignore
                            except Exception:
                                log.exception("Failed to get source %s", item)
        self.prepare_template()

    def _do_render_template_fields(
        self,
        parent: Any,
        template_fields: Iterable[str],
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """
        Render each named attribute of *parent* and store the result back on it.

        :param parent: Object whose attributes are rendered in place.
        :param template_fields: Names of the attributes to render.
        :param context: Context passed through to ``render_template``.
        :param jinja_env: Jinja environment passed through to ``render_template``.
        :param seen_oids: Ids of objects already visited, shared across the recursion.
        """
        for attr_name in template_fields:
            value = getattr(parent, attr_name)
            rendered_content = self.render_template(
                value,
                context,
                jinja_env,
                seen_oids,
            )
            # Compare by identity rather than truthiness: a template may
            # legitimately render to "", 0, [] or False and that result must
            # replace the raw template. Truth-testing would also raise for
            # objects that do not define an unambiguous ``__bool__``.
            if rendered_content is not value:
                setattr(parent, attr_name, rendered_content)

    def _render(self, template, context, dag=None) -> Any:
        """
        Render *template* with *context*.

        Uses ``render_template_as_native`` when *dag* is truthy and has a truthy
        ``render_template_as_native_obj``; otherwise ``render_template_to_string``.
        """
        if dag and dag.render_template_as_native_obj:
            return render_template_as_native(template, context)
        return render_template_to_string(template, context)

    def render_template(
        self,
        content: Any,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
        seen_oids: set[int] | None = None,
    ) -> Any:
        """
        Render a templated string.

        If *content* is a collection holding multiple templated strings, strings
        in the collection will be templated recursively.

        :param content: Content to template. Only strings can be templated (may
            be inside a collection).
        :param context: Dict with values to apply on templated content
        :param jinja_env: Jinja environment. Can be provided to avoid
            re-creating Jinja environments during recursion.
        :param seen_oids: template fields already rendered (to avoid
            *RecursionError* on circular dependencies)
        :return: Templated content
        """
        # "content" is a bad name, but we're stuck to it being public API.
        value = content
        del content

        if seen_oids is not None:
            oids = seen_oids
        else:
            oids = set()

        # Already visited during this rendering pass: hand it back untouched.
        if id(value) in oids:
            return value

        if not jinja_env:
            jinja_env = self.get_template_env()

        if isinstance(value, str):
            if value.endswith(tuple(self.template_ext)):  # A filepath.
                template = jinja_env.get_template(value)
            else:
                template = jinja_env.from_string(value)
            return self._render(template, context)
        if isinstance(value, ObjectStoragePath):
            return self._render_object_storage_path(value, context, jinja_env)

        # Anything exposing a truthy ``resolve`` attribute resolves itself.
        if resolve := getattr(value, "resolve", None):
            return resolve(context)

        # Fast path for common built-in collections.
        if value.__class__ is tuple:
            return tuple(self.render_template(element, context, jinja_env, oids) for element in value)
        if isinstance(value, tuple):  # Special case for named tuples.
            return value.__class__(*(self.render_template(el, context, jinja_env, oids) for el in value))
        if isinstance(value, list):
            return [self.render_template(element, context, jinja_env, oids) for element in value]
        if isinstance(value, dict):
            return {k: self.render_template(v, context, jinja_env, oids) for k, v in value.items()}
        if isinstance(value, set):
            return {self.render_template(element, context, jinja_env, oids) for element in value}

        # More complex collections.
        self._render_nested_template_fields(value, context, jinja_env, oids)
        return value

    def _render_object_storage_path(
        self, value: ObjectStoragePath, context: Context, jinja_env: jinja2.Environment
    ) -> ObjectStoragePath:
        """
        Render the ``"path"`` entry of a serialized ``ObjectStoragePath``.

        The path is serialized, its ``"path"`` item is rendered as an inline
        template, and the result is deserialized with the original version.
        """
        serialized_path = value.serialize()
        path_version = value.__version__
        serialized_path["path"] = self._render(jinja_env.from_string(serialized_path["path"]), context)
        return value.deserialize(data=serialized_path, version=path_version)

    def _render_nested_template_fields(
        self,
        value: Any,
        context: Context,
        jinja_env: jinja2.Environment,
        seen_oids: set[int],
    ) -> None:
        """
        Render, in place, the ``template_fields`` of an arbitrary object.

        The object's id is recorded in *seen_oids* first so that it is rendered
        at most once; objects without a ``template_fields`` attribute are left
        untouched.
        """
        if id(value) in seen_oids:
            return
        seen_oids.add(id(value))
        try:
            nested_template_fields = value.template_fields
        except AttributeError:
            # content has no inner template fields
            return
        self._do_render_template_fields(value, nested_template_fields, context, jinja_env, seen_oids)

224 

225 

class _AirflowEnvironmentMixin:
    """Mixin for Jinja environments that installs the module's ``FILTERS`` and relaxes attribute checks."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Register every filter from the module-level FILTERS mapping.
        for filter_name, filter_func in FILTERS.items():
            self.filters[filter_name] = filter_func

    def is_safe_attribute(self, obj, attr, value):
        """
        Allow access to ``_`` prefix vars (but not ``__``).

        The stock SandboxedEnvironment rejects every attribute starting with ``_``; here only the
        attributes that ``jinja2.sandbox.is_internal_attribute`` reports as internal are rejected, so
        single-underscore "private" attributes remain reachable.
        """
        is_internal = jinja2.sandbox.is_internal_attribute(obj, attr)
        return not is_internal

240 

241 

class NativeEnvironment(_AirflowEnvironmentMixin, jinja2.nativetypes.NativeEnvironment):
    """
    NativeEnvironment for Airflow task templates.

    Combines ``jinja2.nativetypes.NativeEnvironment`` with ``_AirflowEnvironmentMixin``.
    """

244 

245 

class SandboxedEnvironment(_AirflowEnvironmentMixin, jinja2.sandbox.SandboxedEnvironment):
    """
    SandboxedEnvironment for Airflow task templates.

    Combines ``jinja2.sandbox.SandboxedEnvironment`` with ``_AirflowEnvironmentMixin``.
    """

248 

249 

def ds_filter(value: datetime.date | datetime.time | None) -> str | None:
    """Format *value* as ``YYYY-MM-DD``; ``None`` is passed through as ``None``."""
    return None if value is None else value.strftime("%Y-%m-%d")

255 

256 

def ds_nodash_filter(value: datetime.date | datetime.time | None) -> str | None:
    """Format *value* as ``YYYYMMDD``; ``None`` is passed through as ``None``."""
    if value is not None:
        return value.strftime("%Y%m%d")
    return None

262 

263 

def ts_filter(value: datetime.date | datetime.time | None) -> str | None:
    """Return ``value.isoformat()``; ``None`` is passed through as ``None``."""
    return None if value is None else value.isoformat()

269 

270 

def ts_nodash_filter(value: datetime.date | datetime.time | None) -> str | None:
    """Format *value* as ``YYYYMMDDTHHMMSS``; ``None`` is passed through as ``None``."""
    if value is not None:
        return value.strftime("%Y%m%dT%H%M%S")
    return None

276 

277 

def ts_nodash_with_tz_filter(value: datetime.date | datetime.time | None) -> str | None:
    """
    Timestamp filter with timezone.

    Returns ``value.isoformat()`` with the ``:`` separators and the date's
    ``-`` separators removed, e.g. ``20240105T030405+0000``. The sign of a
    negative UTC offset is kept (``-0500``), so the timezone stays intact.

    :param value: Date, time or datetime to format; ``None`` is passed through.
    :return: The compact timestamp string, or ``None`` when *value* is ``None``.
    """
    if value is None:
        return None
    compact = value.isoformat().replace(":", "")
    if isinstance(value, datetime.time):
        # A bare time has no date part, so any "-" is the UTC-offset sign.
        return compact
    # Only the date part (before "T") carries separator dashes; a "-" after
    # it is the sign of a negative UTC offset and must be preserved.
    date_part, separator, time_part = compact.partition("T")
    return date_part.replace("-", "") + separator + time_part

283 

284 

# Filter name -> callable mapping; _AirflowEnvironmentMixin.__init__ adds these
# entries to each environment's ``filters``.
FILTERS = {
    "ds": ds_filter,
    "ds_nodash": ds_nodash_filter,
    "ts": ts_filter,
    "ts_nodash": ts_nodash_filter,
    "ts_nodash_with_tz": ts_nodash_with_tz_filter,
}

291}