1from __future__ import annotations
2
3import typing
4import warnings
5from os import PathLike
6
7from starlette.background import BackgroundTask
8from starlette.datastructures import URL
9from starlette.requests import Request
10from starlette.responses import HTMLResponse
11from starlette.types import Receive, Scope, Send
12
# jinja2 is an optional dependency: when it cannot be imported the module still
# loads, with ``jinja2`` bound to ``None``.
try:
    import jinja2
except ModuleNotFoundError:  # pragma: nocover
    jinja2 = None  # type: ignore[assignment]
else:
    # Jinja 3.0 renamed @contextfunction to @pass_context and 3.1 dropped the
    # old name, so prefer pass_context (what most installs provide) and only
    # fall back to contextfunction when pass_context is missing.
    # The type ignore lets mypy accept an attribute that may not exist.
    pass_context = (
        jinja2.pass_context
        if hasattr(jinja2, "pass_context")
        else jinja2.contextfunction  # type: ignore[attr-defined]
    )
26
27
class _TemplateResponse(HTMLResponse):
    """
    HTML response whose body is produced by rendering a template.

    The template object and the context it was rendered with are kept on the
    instance as ``template`` and ``context``.
    """

    def __init__(
        self,
        template: typing.Any,
        context: dict[str, typing.Any],
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ):
        """
        Render ``template`` with ``context`` and use the result as the body.

        The remaining arguments are forwarded unchanged to the parent
        response constructor.
        """
        self.template = template
        self.context = context
        super().__init__(
            template.render(context),
            status_code,
            headers,
            media_type,
            background,
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Send the response, preceded by a debug message when it is requested.

        If the ``request`` stored in the context lists the
        ``http.response.debug`` extension, a message carrying the template and
        the context is sent before the regular response.
        """
        advertised = self.context.get("request", {}).get("extensions", {})
        if "http.response.debug" in advertised:
            debug_message = {
                "type": "http.response.debug",
                "info": {
                    "template": self.template,
                    "context": self.context,
                },
            }
            await send(debug_message)
        await super().__call__(scope, receive, send)
57
58
class Jinja2Templates:
    """
    Build template responses from a Jinja2 environment.

    templates = Jinja2Templates("templates")

    return templates.TemplateResponse(request, "index.html")

    Either a template ``directory`` (an environment is created for it) or an
    already configured ``env`` must be supplied.
    """

    @typing.overload
    def __init__(
        self,
        directory: str
        | PathLike[typing.AnyStr]
        | typing.Sequence[str | PathLike[typing.AnyStr]],
        *,
        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
        | None = None,
        **env_options: typing.Any,
    ) -> None:
        ...

    @typing.overload
    def __init__(
        self,
        *,
        env: jinja2.Environment,
        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
        | None = None,
    ) -> None:
        ...

    def __init__(
        self,
        directory: str
        | PathLike[typing.AnyStr]
        | typing.Sequence[str | PathLike[typing.AnyStr]]
        | None = None,
        *,
        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
        | None = None,
        env: jinja2.Environment | None = None,
        **env_options: typing.Any,
    ) -> None:
        """
        Configure the Jinja2 environment used to load templates.

        :param directory: Template search path(s); when given, a new
            environment is created for it and ``env`` is not used.
        :param context_processors: Callables invoked with the request for each
            response; the dict each one returns is merged into the context.
        :param env: Preconfigured environment, used when ``directory`` is not
            given.
        :param env_options: Deprecated extra keyword arguments forwarded to
            ``jinja2.Environment`` when an environment is created.
        :raises AssertionError: If jinja2 is not installed, or if neither
            ``directory`` nor ``env`` is provided.
        """
        if env_options:
            warnings.warn(
                "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",  # noqa: E501
                DeprecationWarning,
            )
        # These stay as assertions so the exception type seen by existing
        # callers (AssertionError) does not change.
        assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
        assert directory or env, "either 'directory' or 'env' arguments must be passed"
        self.context_processors = context_processors or []
        if directory is not None:
            self.env = self._create_env(directory, **env_options)
        elif env is not None:
            self.env = env

        self._setup_env_defaults(self.env)

    def _create_env(
        self,
        directory: str
        | PathLike[typing.AnyStr]
        | typing.Sequence[str | PathLike[typing.AnyStr]],
        **env_options: typing.Any,
    ) -> jinja2.Environment:
        """
        Create an environment that loads templates from ``directory``.

        ``loader`` and ``autoescape`` are only defaults: values supplied in
        ``env_options`` take precedence over the file system loader and
        ``autoescape=True``.
        """
        loader = jinja2.FileSystemLoader(directory)
        env_options.setdefault("loader", loader)
        env_options.setdefault("autoescape", True)

        return jinja2.Environment(**env_options)

    def _setup_env_defaults(self, env: jinja2.Environment) -> None:
        """
        Register the ``url_for`` template global on ``env``.

        A ``url_for`` global that is already present on the environment is
        left untouched.
        """

        @pass_context
        def url_for(
            context: dict[str, typing.Any],
            name: str,
            /,
            **path_params: typing.Any,
        ) -> URL:
            # The request is read from the template context, so templates can
            # call url_for(name, ...) without passing it explicitly.
            request: Request = context["request"]
            return request.url_for(name, **path_params)

        env.globals.setdefault("url_for", url_for)

    def get_template(self, name: str) -> jinja2.Template:
        """Return the template called ``name`` from the configured environment."""
        return self.env.get_template(name)

    @typing.overload
    def TemplateResponse(
        self,
        request: Request,
        name: str,
        context: dict[str, typing.Any] | None = None,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> _TemplateResponse:
        ...

    @typing.overload
    def TemplateResponse(
        self,
        name: str,
        context: dict[str, typing.Any] | None = None,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> _TemplateResponse:
        # Deprecated usage
        ...

    def TemplateResponse(
        self, *args: typing.Any, **kwargs: typing.Any
    ) -> _TemplateResponse:
        """
        Render the template ``name`` and wrap the result in a response.

        Two calling conventions are accepted:

        * ``TemplateResponse(request, name, context, status_code, headers,
          media_type, background)`` - the current form.
        * ``TemplateResponse(name, context, status_code, headers, media_type,
          background)`` - deprecated; the request is then taken from
          ``context["request"]``.

        Every argument may also be passed by keyword. The context dict is
        updated in place: ``request`` is added when missing and the result of
        each context processor is merged into it.

        :raises ValueError: If no request is passed and the context has no
            ``"request"`` key.
        :raises KeyError: If the template name is passed neither positionally
            nor as ``name=``.
        """
        (
            request,
            name,
            context,
            status_code,
            headers,
            media_type,
            background,
        ) = self._resolve_response_args(args, kwargs)

        context.setdefault("request", request)
        for context_processor in self.context_processors:
            context.update(context_processor(request))

        template = self.get_template(name)
        return _TemplateResponse(
            template,
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )

    @staticmethod
    def _resolve_response_args(
        args: tuple[typing.Any, ...],
        kwargs: dict[str, typing.Any],
    ) -> tuple[
        typing.Any,
        str,
        dict[str, typing.Any],
        int,
        typing.Any,
        typing.Any,
        typing.Any,
    ]:
        """
        Normalise the arguments accepted by ``TemplateResponse``.

        Returns ``(request, name, context, status_code, headers, media_type,
        background)``. A context of ``None`` (or no context at all) becomes a
        new empty dict; a dict supplied by the caller is returned as-is.
        """
        # A leading string means the deprecated name-first call style.
        legacy = bool(args) and isinstance(args[0], str)
        if legacy:
            warnings.warn(
                "The `name` is not the first parameter anymore. "
                "The first parameter should be the `Request` instance.\n"
                'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',  # noqa: E501
                DeprecationWarning,
            )
        elif not args and "request" not in kwargs:
            warnings.warn(
                "The `TemplateResponse` now requires the `request` argument.\n"
                'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',  # noqa: E501
                DeprecationWarning,
            )

        # Position of ``context``. The legacy order has no leading request, so
        # every optional argument sits one slot earlier than in the new order.
        first_optional = 1 if legacy else 2

        def pick(offset: int, key: str, default: typing.Any = None) -> typing.Any:
            # Positional value when enough were given, keyword value otherwise.
            index = first_optional + offset
            return args[index] if len(args) > index else kwargs.get(key, default)

        context = pick(0, "context")
        if context is None:
            context = {}

        if legacy:
            name = args[0]
            if "request" not in context:
                raise ValueError('context must include a "request" key')
            request = context["request"]
        elif args:
            request = args[0]
            name = args[1] if len(args) > 1 else kwargs["name"]
        else:
            if "request" not in kwargs and "request" not in context:
                raise ValueError('context must include a "request" key')
            request = kwargs.get("request", context.get("request"))
            name = typing.cast(str, kwargs["name"])

        return (
            request,
            name,
            context,
            pick(1, "status_code", 200),
            pick(2, "headers"),
            pick(3, "media_type"),
            pick(4, "background"),
        )