Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/flask_wtf/csrf.py: 37%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1import hashlib
2import hmac
3import logging
4import os
5from urllib.parse import urlparse
7from flask import Blueprint
8from flask import current_app
9from flask import g
10from flask import request
11from flask import session
12from itsdangerous import BadData
13from itsdangerous import SignatureExpired
14from itsdangerous import URLSafeTimedSerializer
15from werkzeug.exceptions import BadRequest
16from wtforms import ValidationError
17from wtforms.csrf.core import CSRF
# Public API of this module.
__all__ = (
    "generate_csrf",
    "validate_csrf",
    "CSRFProtect",
)

# Module logger used to report CSRF validation failure reasons.
logger = logging.getLogger(__name__)
def generate_csrf(secret_key=None, token_key=None):
    """Return the signed CSRF token for the current request.

    The signed token is memoized on ``flask.g`` under the resolved field
    name, so repeated calls within one request return the same value. The
    raw (unsigned) token is kept in the session under that same name and is
    created the first time it is needed.

    When testing, the signed token is available as ``g.csrf_token`` and the
    raw one as ``session['csrf_token']`` (with the default field name).

    :param secret_key: Key used to sign the token. Falls back to
        ``WTF_CSRF_SECRET_KEY``, then the application's ``SECRET_KEY``.
    :param token_key: Name under which the token is stored in the session
        (and on ``g``). Falls back to ``WTF_CSRF_FIELD_NAME``, then
        ``'csrf_token'``.
    """
    signing_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        "csrf_token",
        message="A field name is required to use CSRF.",
    )

    # Already signed earlier in this request: reuse the cached value.
    if name in g:
        return g.get(name)

    serializer = URLSafeTimedSerializer(signing_key, salt="wtf-csrf-token")

    def fresh_raw_token():
        # 40 hex characters derived from 64 bytes of OS randomness.
        return hashlib.sha1(os.urandom(64)).hexdigest()

    if name not in session:
        session[name] = fresh_raw_token()

    try:
        signed = serializer.dumps(session[name])
    except TypeError:
        # Signing the stored session value raised TypeError (presumably it is
        # not serializable); replace it with a fresh token and sign that.
        session[name] = fresh_raw_token()
        signed = serializer.dumps(session[name])

    setattr(g, name, signed)
    return g.get(name)
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Verify a signed CSRF token against the raw token held in the session.

    :param data: The signed CSRF token submitted by the client.
    :param secret_key: Key the token was signed with. Falls back to
        ``WTF_CSRF_SECRET_KEY``, then the application's ``SECRET_KEY``.
    :param time_limit: Maximum token age in seconds. Falls back to
        ``WTF_CSRF_TIME_LIMIT``, then 3600 (one hour). A resolved value of
        ``None`` is passed through to the serializer as ``max_age``.
    :param token_key: Session key holding the raw token. Falls back to
        ``WTF_CSRF_FIELD_NAME``, then ``'csrf_token'``.

    :raises ValidationError: With a message describing why validation failed
        (missing token, missing session token, expired, invalid, or mismatch).

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """
    signing_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        "csrf_token",
        message="A field name is required to use CSRF.",
    )
    max_age = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)

    if not data:
        raise ValidationError("The CSRF token is missing.")

    if name not in session:
        raise ValidationError("The CSRF session token is missing.")

    serializer = URLSafeTimedSerializer(signing_key, salt="wtf-csrf-token")

    # SignatureExpired must be handled before the more general BadData.
    try:
        unsigned = serializer.loads(data, max_age=max_age)
    except SignatureExpired as exc:
        raise ValidationError("The CSRF token has expired.") from exc
    except BadData as exc:
        raise ValidationError("The CSRF token is invalid.") from exc

    # Constant-time comparison so timing does not leak how much matched.
    tokens_match = hmac.compare_digest(session[name], unsigned)
    if not tokens_match:
        raise ValidationError("The CSRF tokens do not match.")
118def _get_config(
119 value, config_name, default=None, required=True, message="CSRF is not configured."
120):
121 """Find config value based on provided value, Flask config, and default
122 value.
124 :param value: already provided config value
125 :param config_name: Flask ``config`` key
126 :param default: default value if not provided or configured
127 :param required: whether the value must not be ``None``
128 :param message: error message if required config is not found
129 :raises KeyError: if required config is not found
130 """
132 if value is None:
133 value = current_app.config.get(config_name, default)
135 if required and value is None:
136 raise RuntimeError(message)
138 return value
class _FlaskFormCSRF(CSRF):
    """WTForms CSRF implementation backed by this module's session tokens.

    The secret, time limit and field name are read from the bound form's
    ``meta`` (``csrf_secret``, ``csrf_time_limit``, ``csrf_field_name``).
    """

    def setup_form(self, form):
        # Remember the form's meta so the token hooks below can read the
        # per-form CSRF settings from it.
        self.meta = form.meta
        return super().setup_form(form)

    def generate_csrf_token(self, csrf_token_field):
        meta = self.meta
        return generate_csrf(
            secret_key=meta.csrf_secret, token_key=meta.csrf_field_name
        )

    def validate_csrf_token(self, form, field):
        # CSRFProtect sets g.csrf_valid after it has checked this request,
        # so the form does not need to validate the token a second time.
        if g.get("csrf_valid", False):
            return

        meta = self.meta
        try:
            validate_csrf(
                field.data,
                secret_key=meta.csrf_secret,
                time_limit=meta.csrf_time_limit,
                token_key=meta.csrf_field_name,
            )
        except ValidationError as exc:
            # Log the failure reason, then let WTForms report the error.
            logger.info(exc.args[0])
            raise
class CSRFProtect:
    """Enable CSRF protection globally for a Flask app.

    ::

        app = Flask(__name__)
        csrf = CSRFProtect(app)

    Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
    header sent with JavaScript requests. Render the token in templates using
    ``{{ csrf_token() }}``.

    See the :ref:`csrf` documentation.
    """

    def __init__(self, app=None):
        # Dotted "module.name" strings of views excluded via ``exempt``.
        self._exempt_views = set()
        # Blueprint objects excluded via ``exempt``.
        self._exempt_blueprints = set()

        # Compare against None explicitly: only a missing app skips setup.
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register CSRF config defaults, template helpers and a request hook.

        Config keys already set on ``app`` are kept; missing ones receive the
        defaults below. ``WTF_CSRF_METHODS`` is always converted to a set.

        :param app: the Flask application to protect.
        """
        app.extensions["csrf"] = self

        app.config.setdefault("WTF_CSRF_ENABLED", True)
        app.config.setdefault("WTF_CSRF_CHECK_DEFAULT", True)
        app.config["WTF_CSRF_METHODS"] = set(
            app.config.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
        )
        app.config.setdefault("WTF_CSRF_FIELD_NAME", "csrf_token")
        app.config.setdefault("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"])
        app.config.setdefault("WTF_CSRF_TIME_LIMIT", 3600)
        app.config.setdefault("WTF_CSRF_SSL_STRICT", True)

        # Make ``csrf_token()`` callable from Jinja templates.
        app.jinja_env.globals["csrf_token"] = generate_csrf
        app.context_processor(lambda: {"csrf_token": generate_csrf})

        @app.before_request
        def csrf_protect():
            # Each guard below skips requests that are out of scope for the
            # automatic check.
            if not app.config["WTF_CSRF_ENABLED"]:
                return

            if not app.config["WTF_CSRF_CHECK_DEFAULT"]:
                return

            if request.method not in app.config["WTF_CSRF_METHODS"]:
                return

            if not request.endpoint:
                return

            if app.blueprints.get(request.blueprint) in self._exempt_blueprints:
                return

            view = app.view_functions.get(request.endpoint)

            # ``view_functions.get`` returns None when no view is registered
            # for the endpoint. There is then nothing to match against the
            # exemption list, so fall through to protection instead of
            # crashing with AttributeError on ``None.__module__``.
            if view is not None:
                dest = f"{view.__module__}.{view.__name__}"

                if dest in self._exempt_views:
                    return

            self.protect()

    def _get_csrf_token(self):
        """Return the CSRF token submitted with the request, or ``None``.

        Lookup order, skipping empty values: the form field named
        ``WTF_CSRF_FIELD_NAME``; any form field whose name ends with that
        name (prefixed forms); then each header in ``WTF_CSRF_HEADERS``.
        """
        # find the token in the form data
        field_name = current_app.config["WTF_CSRF_FIELD_NAME"]
        base_token = request.form.get(field_name)

        if base_token:
            return base_token

        # if the form has a prefix, the name will be {prefix}-csrf_token
        for key in request.form:
            if key.endswith(field_name):
                csrf_token = request.form[key]

                if csrf_token:
                    return csrf_token

        # find the token in the headers
        for header_name in current_app.config["WTF_CSRF_HEADERS"]:
            csrf_token = request.headers.get(header_name)

            if csrf_token:
                return csrf_token

        return None

    def protect(self):
        """Validate the CSRF token of the current request.

        Does nothing for methods outside ``WTF_CSRF_METHODS``. For secure
        requests with ``WTF_CSRF_SSL_STRICT`` enabled, it also requires a
        ``Referer`` header matching ``https://<request host>/`` in scheme,
        host and port. On success it sets ``g.csrf_valid`` so form-level
        validation can skip the check.

        :raises CSRFError: if the token or the referrer check fails.
        """
        if request.method not in current_app.config["WTF_CSRF_METHODS"]:
            return

        try:
            validate_csrf(self._get_csrf_token())
        except ValidationError as e:
            logger.info(e.args[0])
            self._error_response(e.args[0])

        if request.is_secure and current_app.config["WTF_CSRF_SSL_STRICT"]:
            if not request.referrer:
                self._error_response("The referrer header is missing.")

            good_referrer = f"https://{request.host}/"

            if not same_origin(request.referrer, good_referrer):
                self._error_response("The referrer does not match the host.")

        g.csrf_valid = True  # mark this request as CSRF valid

    def exempt(self, view):
        """Mark a view or blueprint to be excluded from CSRF protection.

        ::

            @app.route('/some-view', methods=['POST'])
            @csrf.exempt
            def some_view():
                ...

        ::

            bp = Blueprint(...)
            csrf.exempt(bp)

        A dotted ``"module.function"`` string may also be passed instead of
        the view function itself. Returns ``view`` unchanged so this works as
        a decorator.
        """
        if isinstance(view, Blueprint):
            self._exempt_blueprints.add(view)
            return view

        if isinstance(view, str):
            view_location = view
        else:
            view_location = ".".join((view.__module__, view.__name__))

        self._exempt_views.add(view_location)
        return view

    def _error_response(self, reason):
        """Abort the request by raising ``CSRFError`` with ``reason``."""
        raise CSRFError(reason)
class CSRFError(BadRequest):
    """Signal that the client sent missing or invalid CSRF data.

    By default this produces a 400 Bad Request response that carries the
    failure reason. To customize that response, register a handler with
    :meth:`flask.Flask.errorhandler`.
    """

    # Class-level description for this error type.
    description: str = "CSRF validation failed."
def same_origin(current_uri, compare_uri):
    """Return whether two URLs share the same origin (scheme, host, port).

    Hostnames and schemes are compared as normalized by ``urlparse``
    (lowercased). A missing port is treated as the scheme's default, so
    ``https://host/`` and ``https://host:443/`` match.

    Any URL that cannot be parsed is treated as "not the same origin" rather
    than raising. This matters because ``current_uri`` is typically an
    untrusted ``Referer`` header. Examples of unparseable URLs are a
    non-numeric or out-of-range port, or a malformed IPv6 literal.

    :param current_uri: URL to check, e.g. the request referrer.
    :param compare_uri: URL whose origin is considered trusted.
    :return: ``True`` if scheme, hostname and (effective) port all match.
    """
    default_ports = {"http": 80, "https": 443}

    def _origin(uri):
        # Both urlparse() and .port may raise ValueError on malformed input.
        parsed = urlparse(uri)
        port = parsed.port
        if port is None:
            port = default_ports.get(parsed.scheme)
        return parsed.scheme, parsed.hostname, port

    try:
        return _origin(current_uri) == _origin(compare_uri)
    except ValueError:
        return False