Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/flask_wtf/csrf.py: 37%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

150 statements  

1import hashlib 

2import hmac 

3import logging 

4import os 

5from urllib.parse import urlparse 

6 

7from flask import Blueprint 

8from flask import current_app 

9from flask import g 

10from flask import request 

11from flask import session 

12from itsdangerous import BadData 

13from itsdangerous import SignatureExpired 

14from itsdangerous import URLSafeTimedSerializer 

15from werkzeug.exceptions import BadRequest 

16from wtforms import ValidationError 

17from wtforms.csrf.core import CSRF 

18 

# Names exported by a star-import of this module. Other module-level names
# (e.g. CSRFError, same_origin) remain importable explicitly.
__all__ = ("generate_csrf", "validate_csrf", "CSRFProtect")
# Module logger; CSRF validation failures are reported through it at INFO level.
logger = logging.getLogger(__name__)

21 

22 

def generate_csrf(secret_key=None, token_key=None):
    """Return a signed CSRF token for the current request.

    The raw (unsigned) token is kept in the session under the configured
    field name and is created the first time it is needed. The signed form
    is cached on ``g`` under that same name, so every call made while
    handling one request returns an identical value.

    In tests, the signed token can be read from ``g.csrf_token`` and the raw
    token from ``session['csrf_token']`` (when the default field name is used).

    :param secret_key: Key used to sign the token. Falls back to
        ``WTF_CSRF_SECRET_KEY``, then to the app's ``SECRET_KEY``.
    :param token_key: Session / ``g`` key holding the token. Falls back to
        ``WTF_CSRF_FIELD_NAME``, then to ``'csrf_token'``.
    """
    secret_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    field_name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        "csrf_token",
        message="A field name is required to use CSRF.",
    )

    # Already signed earlier in this request: reuse the cached value.
    if field_name in g:
        return g.get(field_name)

    def _fresh_raw_token():
        # 40 hex characters derived from OS-provided random bytes.
        return hashlib.sha1(os.urandom(64)).hexdigest()

    serializer = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

    if field_name not in session:
        session[field_name] = _fresh_raw_token()

    try:
        signed = serializer.dumps(session[field_name])
    except TypeError:
        # The stored session value could not be serialized; replace it
        # with a brand-new raw token and sign that instead.
        session[field_name] = _fresh_raw_token()
        signed = serializer.dumps(session[field_name])

    setattr(g, field_name, signed)
    return g.get(field_name)

64 

65 

def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Verify a signed CSRF token against the raw token held in the session.

    :param data: The signed CSRF token submitted by the client.
    :param secret_key: Key the token was signed with. Falls back to
        ``WTF_CSRF_SECRET_KEY``, then to the app's ``SECRET_KEY``.
    :param time_limit: Maximum token age in seconds. Falls back to
        ``WTF_CSRF_TIME_LIMIT``, then to 3600 seconds (60 minutes).
    :param token_key: Session key holding the raw token. Falls back to
        ``WTF_CSRF_FIELD_NAME``, then to ``'csrf_token'``.

    :raises ValidationError: With a message describing why validation failed.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """
    secret_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    field_name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        "csrf_token",
        message="A field name is required to use CSRF.",
    )
    time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)

    # Cheap presence checks before any cryptographic work.
    if not data:
        raise ValidationError("The CSRF token is missing.")
    if field_name not in session:
        raise ValidationError("The CSRF session token is missing.")

    serializer = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

    try:
        unsigned = serializer.loads(data, max_age=time_limit)
    # SignatureExpired is handled first so an expired token gets its own
    # message rather than the generic BadData one.
    except SignatureExpired as exc:
        raise ValidationError("The CSRF token has expired.") from exc
    except BadData as exc:
        raise ValidationError("The CSRF token is invalid.") from exc

    expected = session[field_name]
    # Constant-time comparison so the match check does not leak timing data.
    if not hmac.compare_digest(expected, unsigned):
        raise ValidationError("The CSRF tokens do not match.")

116 

117 

118def _get_config( 

119 value, config_name, default=None, required=True, message="CSRF is not configured." 

120): 

121 """Find config value based on provided value, Flask config, and default 

122 value. 

123 

124 :param value: already provided config value 

125 :param config_name: Flask ``config`` key 

126 :param default: default value if not provided or configured 

127 :param required: whether the value must not be ``None`` 

128 :param message: error message if required config is not found 

129 :raises KeyError: if required config is not found 

130 """ 

131 

132 if value is None: 

133 value = current_app.config.get(config_name, default) 

134 

135 if required and value is None: 

136 raise RuntimeError(message) 

137 

138 return value 

139 

140 

class _FlaskFormCSRF(CSRF):
    """WTForms CSRF implementation that stores its token in the Flask session.

    Token generation and validation are delegated to :func:`generate_csrf`
    and :func:`validate_csrf`, using the ``csrf_*`` settings from the bound
    form's ``meta``.
    """

    def setup_form(self, form):
        # Remember the form's meta so the token methods below can read the
        # csrf_secret / csrf_field_name / csrf_time_limit settings.
        self.meta = form.meta
        return super().setup_form(form)

    def generate_csrf_token(self, csrf_token_field):
        meta = self.meta
        return generate_csrf(
            secret_key=meta.csrf_secret, token_key=meta.csrf_field_name
        )

    def validate_csrf_token(self, form, field):
        # CSRFProtect.protect() sets g.csrf_valid once it has checked the
        # request, in which case a second validation is skipped.
        if not g.get("csrf_valid", False):
            meta = self.meta
            try:
                validate_csrf(
                    field.data,
                    meta.csrf_secret,
                    meta.csrf_time_limit,
                    meta.csrf_field_name,
                )
            except ValidationError as exc:
                logger.info(exc.args[0])
                raise

166 

167 

class CSRFProtect:
    """Enable CSRF protection globally for a Flask app.

    ::

        app = Flask(__name__)
        csrf = CSRFProtect(app)

    Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
    header sent with JavaScript requests. Render the token in templates using
    ``{{ csrf_token() }}``.

    See the :ref:`csrf` documentation.
    """

    def __init__(self, app=None):
        # Dotted "module.function" names of views excluded from protection.
        self._exempt_views = set()
        # Blueprint objects whose requests are excluded from protection.
        self._exempt_blueprints = set()

        if app:
            self.init_app(app)

    def init_app(self, app):
        """Register this extension on ``app``.

        Sets the ``WTF_CSRF_*`` config defaults, exposes ``csrf_token`` to
        templates, and installs a ``before_request`` hook that runs
        :meth:`protect` for requests whose method is in ``WTF_CSRF_METHODS``,
        unless the view or its blueprint has been exempted.

        :param app: the Flask application to protect.
        """
        app.extensions["csrf"] = self

        app.config.setdefault("WTF_CSRF_ENABLED", True)
        app.config.setdefault("WTF_CSRF_CHECK_DEFAULT", True)
        # Normalized to a set for membership checks in the hook and protect().
        app.config["WTF_CSRF_METHODS"] = set(
            app.config.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
        )
        app.config.setdefault("WTF_CSRF_FIELD_NAME", "csrf_token")
        app.config.setdefault("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"])
        app.config.setdefault("WTF_CSRF_TIME_LIMIT", 3600)
        app.config.setdefault("WTF_CSRF_SSL_STRICT", True)

        app.jinja_env.globals["csrf_token"] = generate_csrf
        app.context_processor(lambda: {"csrf_token": generate_csrf})

        @app.before_request
        def csrf_protect():
            if not app.config["WTF_CSRF_ENABLED"]:
                return

            # Automatic checking is off; protect() can still be called explicitly.
            if not app.config["WTF_CSRF_CHECK_DEFAULT"]:
                return

            if request.method not in app.config["WTF_CSRF_METHODS"]:
                return

            # No matched endpoint, so there is no view to protect.
            if not request.endpoint:
                return

            if app.blueprints.get(request.blueprint) in self._exempt_blueprints:
                return

            view = app.view_functions.get(request.endpoint)

            # An endpoint may be routed without a registered view function.
            # There is then nothing to match against the exempt list, so fall
            # through to the normal check instead of crashing on None.
            if view is not None:
                dest = f"{view.__module__}.{view.__name__}"

                if dest in self._exempt_views:
                    return

            self.protect()

    def _get_csrf_token(self):
        """Return the CSRF token submitted with the current request, if any.

        Looks, in order, at the ``WTF_CSRF_FIELD_NAME`` form field, then any
        form field whose name ends with that name (prefixed forms), then each
        header listed in ``WTF_CSRF_HEADERS``. The first non-empty value wins;
        ``None`` is returned when nothing is found.
        """
        # find the token in the form data
        field_name = current_app.config["WTF_CSRF_FIELD_NAME"]
        base_token = request.form.get(field_name)

        if base_token:
            return base_token

        # if the form has a prefix, the name will be {prefix}-csrf_token
        for key in request.form:
            if key.endswith(field_name):
                csrf_token = request.form[key]

                if csrf_token:
                    return csrf_token

        # find the token in the headers
        for header_name in current_app.config["WTF_CSRF_HEADERS"]:
            csrf_token = request.headers.get(header_name)

            if csrf_token:
                return csrf_token

        return None

    def protect(self):
        """Validate the CSRF token of the current request.

        Does nothing for methods not listed in ``WTF_CSRF_METHODS``. On success
        sets ``g.csrf_valid`` so per-form validation can skip a second check.

        :raises CSRFError: if the token is missing or invalid, or, for secure
            requests with ``WTF_CSRF_SSL_STRICT`` enabled, if the referrer is
            missing or not same-origin with the request host.
        """
        if request.method not in current_app.config["WTF_CSRF_METHODS"]:
            return

        try:
            validate_csrf(self._get_csrf_token())
        except ValidationError as e:
            logger.info(e.args[0])
            self._error_response(e.args[0])

        if request.is_secure and current_app.config["WTF_CSRF_SSL_STRICT"]:
            if not request.referrer:
                self._error_response("The referrer header is missing.")

            good_referrer = f"https://{request.host}/"

            if not same_origin(request.referrer, good_referrer):
                self._error_response("The referrer does not match the host.")

        g.csrf_valid = True  # mark this request as CSRF valid

    def exempt(self, view):
        """Mark a view or blueprint to be excluded from CSRF protection.

        ::

            @app.route('/some-view', methods=['POST'])
            @csrf.exempt
            def some_view():
                ...

        ::

            bp = Blueprint(...)
            csrf.exempt(bp)

        A string argument is treated as an already-dotted
        ``"module.function"`` view name. The argument is returned unchanged,
        so this works as a decorator.
        """

        if isinstance(view, Blueprint):
            self._exempt_blueprints.add(view)
            return view

        if isinstance(view, str):
            view_location = view
        else:
            view_location = ".".join((view.__module__, view.__name__))

        self._exempt_views.add(view_location)
        return view

    def _error_response(self, reason):
        """Abort the request by raising :class:`CSRFError` with ``reason``."""
        raise CSRFError(reason)

308 

309 

class CSRFError(BadRequest):
    """Error raised when a request fails CSRF validation.

    As a :class:`~werkzeug.exceptions.BadRequest`, it produces a
    400 Bad Request response carrying the failure reason by default. To
    change that response, register a handler for this exception with
    :meth:`flask.Flask.errorhandler`.
    """

    #: Default description of the error.
    description = "CSRF validation failed."

319 

320 

def same_origin(current_uri, compare_uri):
    """Return whether two URLs share the same origin (scheme, host and port).

    An omitted port is treated as the scheme's default port (80 for ``http``,
    443 for ``https``), so ``https://example.com/`` and
    ``https://example.com:443/`` are the same origin.

    Malformed URLs never match. An invalid IPv6 literal, or a non-numeric or
    out-of-range port, makes this return ``False`` instead of raising
    ``ValueError``. ``current_uri`` is typically the client-supplied
    ``Referer`` header, so bad input must not crash the caller.

    :param current_uri: URL to check, e.g. the request referrer.
    :param compare_uri: trusted URL to compare against.
    :return: ``True`` if both URLs have the same origin, else ``False``.
    """
    default_ports = {"http": 80, "https": 443}

    try:
        current = urlparse(current_uri)
        compare = urlparse(compare_uri)
        # ``.port`` is computed lazily and raises ValueError for bad ports.
        current_port = current.port
        compare_port = compare.port
    except ValueError:
        return False

    if current_port is None:
        current_port = default_ports.get(current.scheme)
    if compare_port is None:
        compare_port = default_ports.get(compare.scheme)

    # urlparse lowercases both the scheme and the hostname.
    return (
        current.scheme == compare.scheme
        and current.hostname == compare.hostname
        and current_port == compare_port
    )

329 )