1from __future__ import annotations
2
3import warnings
4from collections.abc import Mapping, Sequence
5from os import PathLike
6from typing import Any, Callable, cast, overload
7
8from starlette.background import BackgroundTask
9from starlette.datastructures import URL
10from starlette.requests import Request
11from starlette.responses import HTMLResponse
12from starlette.types import Receive, Scope, Send
13
try:
    import jinja2

    # Jinja 3.0 renamed the @contextfunction decorator to @pass_context, and
    # Jinja 3.1 removed @contextfunction entirely. Prefer pass_context (present
    # on most installs, which are >=3.1) and fall back to contextfunction for
    # older Jinja versions. The type ignore lets mypy accept an attribute that
    # does not exist in newer Jinja stubs.
    if hasattr(jinja2, "pass_context"):
        pass_context = jinja2.pass_context
    else:  # pragma: no cover
        pass_context = jinja2.contextfunction  # type: ignore[attr-defined]
except ModuleNotFoundError:  # pragma: no cover
    # jinja2 is optional at import time; Jinja2Templates.__init__ asserts that
    # it is installed before anything uses it.
    jinja2 = None  # type: ignore[assignment]
27
28
class _TemplateResponse(HTMLResponse):
    """HTML response whose body is a rendered Jinja2 template.

    The template object and the context it was rendered with are kept on the
    instance (``.template`` / ``.context``). They are emitted in an
    ``http.response.debug`` message when the request supports that extension.
    """

    def __init__(
        self,
        template: Any,
        context: dict[str, Any],
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ):
        # Render eagerly: the body must exist before HTMLResponse is set up.
        rendered_html = template.render(context)
        self.template = template
        self.context = context
        super().__init__(rendered_html, status_code, headers, media_type, background)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the response, preceded by a debug message when supported.

        The ``request`` entry of the context (an empty dict when absent) is asked
        for its ``extensions``. If ``http.response.debug`` is among them, a
        message carrying the template and context is sent first.
        """
        available_extensions = self.context.get("request", {}).get("extensions", {})
        if "http.response.debug" in available_extensions:  # pragma: no branch
            debug_message = {
                "type": "http.response.debug",
                "info": {"template": self.template, "context": self.context},
            }
            await send(debug_message)
        await super().__call__(scope, receive, send)
58
59
class Jinja2Templates:
    """
    Render Jinja2 templates into HTML responses.

    Usage::

        templates = Jinja2Templates("templates")

        return templates.TemplateResponse(request, "index.html")

    Exactly one of ``directory`` (a path or a sequence of paths, loaded with
    ``jinja2.FileSystemLoader``) or ``env`` (a preconfigured
    ``jinja2.Environment``) must be given. Each callable in
    ``context_processors`` is called with the request on every render. The dict
    it returns is merged into the template context.
    """

    @overload
    def __init__(
        self,
        directory: str | PathLike[str] | Sequence[str | PathLike[str]],
        *,
        context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
        **env_options: Any,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        env: jinja2.Environment,
        context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
    ) -> None: ...

    def __init__(
        self,
        directory: str | PathLike[str] | Sequence[str | PathLike[str]] | None = None,
        *,
        context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
        env: jinja2.Environment | None = None,
        **env_options: Any,
    ) -> None:
        """Build or adopt the Jinja2 environment and register template globals.

        Args:
            directory: template directory (or directories) for a new environment.
            context_processors: callables ``request -> dict`` merged into every
                template context.
            env: an existing environment to use instead of ``directory``.
            **env_options: deprecated extra ``jinja2.Environment`` keyword
                arguments, only used together with ``directory``.

        Raises:
            AssertionError: if jinja2 is not installed, or if not exactly one of
                ``directory`` / ``env`` is provided.
        """
        if env_options:
            warnings.warn(
                "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
                DeprecationWarning,
                # Attribute the warning to the caller, not to this module.
                stacklevel=2,
            )
        # NOTE: these checks are asserts, so they are skipped under `python -O`.
        assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
        assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
        self.context_processors = context_processors or []
        if directory is not None:
            self.env = self._create_env(directory, **env_options)
        elif env is not None:  # pragma: no branch
            self.env = env

        self._setup_env_defaults(self.env)

    def _create_env(
        self,
        directory: str | PathLike[str] | Sequence[str | PathLike[str]],
        **env_options: Any,
    ) -> jinja2.Environment:
        """Create an environment that loads templates from ``directory``.

        ``loader`` defaults to a ``FileSystemLoader`` over ``directory`` and
        ``autoescape`` defaults to ``True``. Either can be overridden via
        ``env_options``.
        """
        loader = jinja2.FileSystemLoader(directory)
        env_options.setdefault("loader", loader)
        env_options.setdefault("autoescape", True)

        return jinja2.Environment(**env_options)

    def _setup_env_defaults(self, env: jinja2.Environment) -> None:
        """Register a ``url_for`` template global unless ``env`` already has one."""

        @pass_context
        def url_for(
            context: dict[str, Any],
            name: str,
            /,
            **path_params: Any,
        ) -> URL:
            # Resolve the route against the request stored in the render context.
            request: Request = context["request"]
            return request.url_for(name, **path_params)

        env.globals.setdefault("url_for", url_for)

    def get_template(self, name: str) -> jinja2.Template:
        """Load template ``name`` from the configured environment."""
        return self.env.get_template(name)

    @staticmethod
    def _copy_context(context: dict[str, Any] | None) -> dict[str, Any]:
        """Return a fresh dict holding the caller's context entries (``None`` -> empty).

        The context is copied so that adding ``request`` and context-processor
        output never mutates a dict owned (and possibly reused) by the caller.
        """
        return {} if context is None else dict(context)

    @overload
    def TemplateResponse(
        self,
        request: Request,
        name: str,
        context: dict[str, Any] | None = None,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> _TemplateResponse: ...

    @overload
    def TemplateResponse(
        self,
        name: str,
        context: dict[str, Any] | None = None,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> _TemplateResponse:
        # Deprecated usage
        ...

    def TemplateResponse(self, *args: Any, **kwargs: Any) -> _TemplateResponse:
        """Render template ``name`` and wrap the result in an HTML response.

        Preferred form::

            TemplateResponse(request, name, context=None, status_code=200,
                             headers=None, media_type=None, background=None)

        Two deprecated forms are still accepted and emit a ``DeprecationWarning``:
        ``TemplateResponse(name, context, ...)`` and an all-keyword call without
        ``request``. In both, the request is taken from ``context["request"]``.

        ``context`` may be ``None``. The caller's dict is copied, never mutated.
        The request is added under ``"request"`` unless the context already has
        that key. Each context processor's output is then merged in.

        Raises:
            ValueError: a deprecated call did not provide ``context["request"]``.
        """
        if args:
            if isinstance(args[0], str):  # the first argument is template name (old style)
                warnings.warn(
                    "The `name` is not the first parameter anymore. "
                    "The first parameter should be the `Request` instance.\n"
                    'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
                    DeprecationWarning,
                    stacklevel=2,
                )

                name = args[0]
                context = self._copy_context(args[1] if len(args) > 1 else kwargs.get("context"))
                status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
                headers = args[3] if len(args) > 3 else kwargs.get("headers")
                media_type = args[4] if len(args) > 4 else kwargs.get("media_type")
                background = args[5] if len(args) > 5 else kwargs.get("background")

                if "request" not in context:
                    raise ValueError('context must include a "request" key')
                request = context["request"]
            else:  # the first argument is a request instance (new style)
                request = args[0]
                name = args[1] if len(args) > 1 else kwargs["name"]
                context = self._copy_context(args[2] if len(args) > 2 else kwargs.get("context"))
                status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
                headers = args[4] if len(args) > 4 else kwargs.get("headers")
                media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
                background = args[6] if len(args) > 6 else kwargs.get("background")
        else:  # all arguments are kwargs
            context = self._copy_context(kwargs.get("context"))
            if "request" not in kwargs:
                warnings.warn(
                    "The `TemplateResponse` now requires the `request` argument.\n"
                    'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
                    DeprecationWarning,
                    stacklevel=2,
                )
                if "request" not in context:
                    raise ValueError('context must include a "request" key')

            request = kwargs.get("request", context.get("request"))
            name = cast(str, kwargs["name"])
            status_code = kwargs.get("status_code", 200)
            headers = kwargs.get("headers")
            media_type = kwargs.get("media_type")
            background = kwargs.get("background")

        # An explicit "request" already in the context takes precedence.
        context.setdefault("request", request)
        for context_processor in self.context_processors:
            context.update(context_processor(request))

        template = self.get_template(name)
        return _TemplateResponse(
            template,
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )