Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/flask_wtf/csrf.py: 37%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import hashlib
2import hmac
3import logging
4import os
5from urllib.parse import urlparse
7from flask import Blueprint
8from flask import current_app
9from flask import g
10from flask import request
11from flask import session
12from itsdangerous import BadData
13from itsdangerous import SignatureExpired
14from itsdangerous import URLSafeTimedSerializer
15from markupsafe import escape
16from markupsafe import Markup
17from werkzeug.exceptions import BadRequest
18from wtforms import ValidationError
19from wtforms.csrf.core import CSRF
# Names exported by a star-import of this module.
__all__ = ("generate_csrf", "validate_csrf", "csrf_meta_tag", "CSRFProtect")
# Module-level logger; CSRF validation failures below are logged at INFO.
logger = logging.getLogger(__name__)
def generate_csrf(secret_key=None, token_key=None):
    """Return the signed CSRF token for the current request.

    The signed token is stored on ``flask.g`` under the field name, so
    repeated calls during one request return the same value. The unsigned
    (raw) token lives in the session under the same key and is created on
    first use.

    In tests, the signed token can be read from ``g.csrf_token`` and the raw
    token from ``session['csrf_token']`` (with the default field name).

    :param secret_key: Key used to sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param token_key: Session key under which the raw token is stored.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
    :raises RuntimeError: If the secret key or field name cannot be resolved.
    """
    secret_key = _get_config(
        value=secret_key,
        config_name="WTF_CSRF_SECRET_KEY",
        default=current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    field_name = _get_config(
        value=token_key,
        config_name="WTF_CSRF_FIELD_NAME",
        default="csrf_token",
        message="A field name is required to use CSRF.",
    )

    # Already generated earlier in this request: reuse it.
    if field_name in g:
        return g.get(field_name)

    serializer = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

    if field_name not in session:
        # Raw token: SHA-1 hex digest of 64 random bytes.
        session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()

    try:
        signed_token = serializer.dumps(session[field_name])
    except TypeError:
        # Presumably the stored session value could not be serialized;
        # replace it with a fresh raw token and sign that instead.
        session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
        signed_token = serializer.dumps(session[field_name])

    setattr(g, field_name, signed_token)
    return g.get(field_name)
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Validate a signed CSRF token against the raw token in the session.

    The signed token is unsigned with the configured secret, checked against
    the time limit, and then compared in constant time with the session's raw
    token. Returns ``None`` on success.

    :param data: The signed CSRF token to be checked.
    :param secret_key: Key used to sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param time_limit: Number of seconds that the token is valid. Default is
        ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
    :param token_key: Session key under which the raw token is stored.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.

    :raises ValidationError: Contains the reason that validation failed.
    :raises RuntimeError: If the secret key or field name cannot be resolved.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """
    secret_key = _get_config(
        value=secret_key,
        config_name="WTF_CSRF_SECRET_KEY",
        default=current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    field_name = _get_config(
        value=token_key,
        config_name="WTF_CSRF_FIELD_NAME",
        default="csrf_token",
        message="A field name is required to use CSRF.",
    )
    # Not required: the resolved time limit is allowed to be ``None``.
    time_limit = _get_config(
        value=time_limit,
        config_name="WTF_CSRF_TIME_LIMIT",
        default=3600,
        required=False,
    )

    if not data:
        raise ValidationError("The CSRF token is missing.")

    if field_name not in session:
        raise ValidationError("The CSRF session token is missing.")

    serializer = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

    try:
        raw_token = serializer.loads(data, max_age=time_limit)
    except SignatureExpired as exc:
        # Must be handled before BadData so expiry gets its own message.
        raise ValidationError("The CSRF token has expired.") from exc
    except BadData as exc:
        raise ValidationError("The CSRF token is invalid.") from exc

    # Constant-time comparison of the session token and the unsigned token.
    tokens_match = hmac.compare_digest(session[field_name], raw_token)
    if not tokens_match:
        raise ValidationError("The CSRF tokens do not match.")
def csrf_meta_tag(name=None, secret_key=None, token_key=None):
    """Build an HTML ``<meta>`` tag whose ``content`` is the CSRF token.

    Intended for SPA and AJAX clients: read the token client-side, e.g. with
    ``document.querySelector('meta[name="csrf-token"]').content``, and send it
    in the ``X-CSRFToken`` header of state-changing requests.

    :param name: Value of the meta tag's ``name`` attribute. Default is
        ``WTF_CSRF_META_NAME`` or ``'csrf-token'``.
    :param secret_key: Forwarded to :func:`generate_csrf`.
    :param token_key: Forwarded to :func:`generate_csrf`.
    :return: The tag as ``Markup``, with both attribute values HTML-escaped.
    """
    meta_name = _get_config(name, "WTF_CSRF_META_NAME", "csrf-token")
    signed_token = generate_csrf(secret_key=secret_key, token_key=token_key)

    # Escape both attribute values before interpolating them into the tag.
    safe_name = escape(meta_name)
    safe_token = escape(signed_token)
    return Markup(f'<meta name="{safe_name}" content="{safe_token}">')
139def _get_config(
140 value, config_name, default=None, required=True, message="CSRF is not configured."
141):
142 """Find config value based on provided value, Flask config, and default
143 value.
145 :param value: already provided config value
146 :param config_name: Flask ``config`` key
147 :param default: default value if not provided or configured
148 :param required: whether the value must not be ``None``
149 :param message: error message if required config is not found
150 :raises KeyError: if required config is not found
151 """
153 if value is None:
154 value = current_app.config.get(config_name, default)
156 if required and value is None:
157 raise RuntimeError(message)
159 return value
class _FlaskFormCSRF(CSRF):
    """WTForms CSRF implementation backed by this module's token helpers."""

    def setup_form(self, form):
        """Keep a reference to the form's ``meta`` and defer to the base class."""
        self.meta = form.meta
        return super().setup_form(form)

    def generate_csrf_token(self, csrf_token_field):
        """Return the signed token, using the secret and field name from meta."""
        meta = self.meta
        return generate_csrf(
            secret_key=meta.csrf_secret,
            token_key=meta.csrf_field_name,
        )

    def validate_csrf_token(self, form, field):
        """Validate the submitted token unless the request is already marked valid.

        :raises ValidationError: Re-raised after logging the failure reason.
        """
        already_valid = g.get("csrf_valid", False)
        if already_valid:
            # already validated by CSRFProtect
            return

        meta = self.meta
        try:
            validate_csrf(
                data=field.data,
                secret_key=meta.csrf_secret,
                time_limit=meta.csrf_time_limit,
                token_key=meta.csrf_field_name,
            )
        except ValidationError as exc:
            logger.info(exc.args[0])
            raise
class CSRFProtect:
    """Enable CSRF protection globally for a Flask app.

    ::

        app = Flask(__name__)
        csrf = CSRFProtect(app)

    The token is read from the submitted form under the configured field name
    (``csrf_token`` by default), or from one of the configured request headers
    (``X-CSRFToken`` / ``X-CSRF-Token`` by default) for JavaScript requests.
    Render the token in templates using ``{{ csrf_token() }}`` or
    ``{{ csrf_meta_tag() }}``.

    See the :ref:`csrf` documentation.
    """

    def __init__(self, app=None):
        """Create the extension and, if ``app`` is given, initialize it."""
        # Exempt views are stored as "module.name" strings, blueprints as objects.
        self._exempt_views = set()
        self._exempt_blueprints = set()

        if app:
            self.init_app(app)

    def init_app(self, app):
        """Register the extension on ``app``.

        Fills in config defaults, exposes ``csrf_token`` and ``csrf_meta_tag``
        to templates, and installs a ``before_request`` hook that runs
        :meth:`protect` with exemptions applied.
        """
        app.extensions["csrf"] = self

        settings = app.config
        for key, fallback in (
            ("WTF_CSRF_ENABLED", True),
            ("WTF_CSRF_CHECK_DEFAULT", True),
        ):
            settings.setdefault(key, fallback)

        # Always normalized to a set, whether configured or defaulted.
        settings["WTF_CSRF_METHODS"] = set(
            settings.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
        )

        for key, fallback in (
            ("WTF_CSRF_FIELD_NAME", "csrf_token"),
            ("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"]),
            ("WTF_CSRF_META_NAME", "csrf-token"),
            ("WTF_CSRF_TIME_LIMIT", 3600),
            ("WTF_CSRF_SSL_STRICT", True),
        ):
            settings.setdefault(key, fallback)

        template_globals = app.jinja_env.globals
        template_globals["csrf_token"] = generate_csrf
        template_globals["csrf_meta_tag"] = csrf_meta_tag

        def _template_helpers():
            return {"csrf_token": generate_csrf, "csrf_meta_tag": csrf_meta_tag}

        app.context_processor(_template_helpers)

        @app.before_request
        def csrf_protect():
            # Both switches are read at request time so they can be toggled.
            if app.config["WTF_CSRF_ENABLED"] and app.config["WTF_CSRF_CHECK_DEFAULT"]:
                self.protect(apply_exemptions=True)

    def _get_csrf_token(self):
        """Return the first non-empty token found in the request, or ``None``.

        Lookup order: the exact form field, any form field whose name ends
        with the field name, then each configured header.
        """
        field_name = current_app.config["WTF_CSRF_FIELD_NAME"]

        # find the token in the form data
        exact_match = request.form.get(field_name)
        if exact_match:
            return exact_match

        # if the form has a prefix, the name will be {prefix}-csrf_token
        prefixed_values = (
            request.form[key] for key in request.form if key.endswith(field_name)
        )
        from_prefixed_field = next(filter(None, prefixed_values), None)
        if from_prefixed_field:
            return from_prefixed_field

        # find the token in the headers
        header_names = current_app.config["WTF_CSRF_HEADERS"]
        header_values = (request.headers.get(header) for header in header_names)
        return next(filter(None, header_values), None)

    def protect(self, apply_exemptions=False):
        """Validate CSRF on the current request.

        When ``apply_exemptions`` is ``True``, requests without an endpoint and
        views or blueprints marked with :meth:`exempt` are skipped. This lets
        you combine a custom ``before_request`` hook (or any manual call) with
        the declarative ``@csrf.exempt`` decorator.

        Requests whose method is not in ``WTF_CSRF_METHODS`` are not checked.
        On secure requests with ``WTF_CSRF_SSL_STRICT`` enabled, the referrer
        must also be present and match the request host.
        """
        if apply_exemptions and (not request.endpoint or self._is_exempt()):
            return

        if request.method not in current_app.config["WTF_CSRF_METHODS"]:
            return

        try:
            validate_csrf(self._get_csrf_token())
        except ValidationError as exc:
            reason = exc.args[0]
            logger.info(reason)
            self._error_response(reason)

        if request.is_secure and current_app.config["WTF_CSRF_SSL_STRICT"]:
            if not request.referrer:
                self._error_response("The referrer header is missing.")

            expected_origin = f"https://{request.host}/"

            if not same_origin(request.referrer, expected_origin):
                self._error_response("The referrer does not match the host.")

        g.csrf_valid = True  # mark this request as CSRF valid

    def _is_exempt(self):
        """Return whether the current request's blueprint or view is exempt."""
        active_blueprint = current_app.blueprints.get(request.blueprint)
        if active_blueprint in self._exempt_blueprints:
            return True

        view = current_app.view_functions.get(request.endpoint)
        if view is None:
            return False

        return "{}.{}".format(view.__module__, view.__name__) in self._exempt_views

    def exempt(self, view):
        """Mark a view or blueprint to be excluded from CSRF protection.

        ::

            @app.route('/some-view', methods=['POST'])
            @csrf.exempt
            def some_view():
                ...

        ::

            bp = Blueprint(...)
            csrf.exempt(bp)

        A string is taken as an already-qualified ``module.name`` location.
        The argument is returned unchanged so this works as a decorator.
        """
        if isinstance(view, Blueprint):
            self._exempt_blueprints.add(view)
            return view

        location = (
            view
            if isinstance(view, str)
            else ".".join((view.__module__, view.__name__))
        )
        self._exempt_views.add(location)
        return view

    def _error_response(self, reason):
        """Reject the request by raising ``CSRFError`` with ``reason``."""
        raise CSRFError(reason)
class CSRFError(BadRequest):
    """Raised when the client sends invalid CSRF data with the request.

    As a ``BadRequest`` subclass it produces a 400 Bad Request response that
    carries the failure reason by default. To customize the response,
    register a handler with :meth:`flask.Flask.errorhandler`.
    """

    # Fallback text used when no specific reason is supplied.
    description = "CSRF validation failed."
def same_origin(current_uri, compare_uri):
    """Return whether two URIs share the same scheme, host, and port.

    A missing port is treated as the scheme's default (80 for ``http``, 443
    for ``https``), so ``https://host:443/`` and ``https://host/`` compare
    equal. For other schemes a missing port only matches a missing port.

    ``current_uri`` is typically the client-supplied ``Referer`` header, so a
    URI that cannot be parsed (for example a non-numeric or out-of-range
    port, or a malformed IPv6 host) is reported as a different origin instead
    of letting ``ValueError`` escape.

    :param current_uri: URI to check, e.g. the request referrer.
    :param compare_uri: URI to compare against, e.g. the expected origin.
    :return: ``True`` if both URIs have the same origin, else ``False``.
    """
    default_ports = {"http": 80, "https": 443}

    def origin(uri):
        # Both urlparse() and the .port property can raise ValueError.
        parts = urlparse(uri)
        port = parts.port
        if port is None:
            port = default_ports.get(parts.scheme)
        return parts.scheme, parts.hostname, port

    try:
        return origin(current_uri) == origin(compare_uri)
    except ValueError:
        return False