Created October 14, 2010 21:15
Save rslinckx/627062 to your computer and use it in GitHub Desktop.
CsrfProtector for Flask applications
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hmac
import logging
import os
from datetime import timedelta
from hashlib import sha1
from urllib.parse import urlparse

from flask import request, g, abort
# NOTE(review): unused in this module and private to Flask; newer Flask
# releases may no longer provide it from flask.helpers — verify and drop.
from flask.helpers import _endpoint_from_view_func

log = logging.getLogger(__name__)

__all__ = ['CsrfProtector']
class CsrfToken(object):
    """A CSRF token for the current request (double-submit cookie scheme).

    The value comes from the CSRF cookie when the client sent one; otherwise
    a fresh random value is generated. Calling the instance returns the value
    and flags it as used; the protector's after-request hook only writes the
    cookie for tokens whose ``used`` flag is set.

    :param secret: HMAC key used when generating a new token (``str`` or
        ``bytes``; ``str`` is UTF-8 encoded).
    :param token: an existing token value (e.g. read from the cookie); a
        falsy value causes a new token to be generated.
    """

    def __init__(self, secret, token=None):
        self.secret = secret
        # Set to True inside generate(); stays False for cookie-supplied tokens.
        self.new = False
        self.token = token or self.generate()
        # Flipped by __call__ once the token value has been handed out.
        self.used = False

    def __call__(self):
        """Return the token value and mark it as used."""
        self.used = True
        return self.token

    def __repr__(self):
        return 'CsrfToken(%s)' % self.token

    def generate(self):
        """Mark the token as new and return a fresh 40-character hex value.

        Unpredictability comes from ``os.urandom``; ``secret`` is the HMAC
        key. ``hmac.new`` requires a bytes key on Python 3, so a ``str``
        secret (e.g. a string ``SECRET_KEY``) is UTF-8 encoded first.
        """
        self.new = True
        key = self.secret
        if isinstance(key, str):
            key = key.encode('utf-8')
        return hmac.new(key, os.urandom(64), sha1).hexdigest()

    @staticmethod
    def _referrer_matches_host(referrer, host):
        """Return True if *referrer* is an https URL whose netloc is *host*.

        The URL is parsed rather than prefix-matched. A Referer header is a
        full URL (``https://host/path``), so it never starts with the bare
        host. A prefix test would also accept ``https://host.evil.org``.
        The comparison is case-insensitive because hostnames are.
        """
        parsed = urlparse(referrer)
        # Only called for secure requests: an http Referer is rejected too.
        return parsed.scheme == 'https' and parsed.netloc.lower() == host.lower()

    def check(self):
        """Return True if the current ``request`` passes CSRF validation.

        Rules, applied in order:

        * methods other than PUT/POST/DELETE/PATCH are accepted;
        * requests with ``X-Requested-With: XMLHttpRequest`` are accepted;
        * HTTPS requests must carry a Referer pointing at this host;
        * the ``csrf_token`` form field must equal this token.

        Each failure is logged at INFO level before returning False.
        """
        # PATCH changes state as well (RFC 5789) and must not bypass the check.
        if request.method not in ("PUT", "POST", "DELETE", "PATCH"):
            return True
        # Same test Werkzeug's ``request.is_xhr`` performed; that property was
        # removed in Werkzeug 1.0, so read the header directly.
        if request.headers.get('X-Requested-With', '').lower() == 'xmlhttprequest':
            return True
        if request.is_secure:
            if not request.referrer:
                log.info('Invalid CSRF: HTTPS with no referrer')
                return False
            if not self._referrer_matches_host(request.referrer, request.host):
                log.info('Invalid CSRF: HTTPS referrer does not match: %r != %r',
                         request.referrer, request.host)
                return False
        form_csrf = request.form.get('csrf_token', None)
        # Constant-time comparison. Encoding first means non-ASCII input
        # cannot make compare_digest raise TypeError.
        if form_csrf is None or not hmac.compare_digest(
                form_csrf.encode('utf-8'), self.token.encode('utf-8')):
            log.info('Invalid CSRF: Form token does not match cookie token: %r != %r',
                     form_csrf, self.token)
            return False
        return True
class CsrfProtector(object):
    """Flask extension that rejects cross-site form submissions.

    Before each request a ``CsrfToken`` is stored as ``g.csrf_token``, built
    from the CSRF cookie when present. Unless checking is disabled, the
    endpoint has no view function, or that view function was marked with
    :meth:`exempt`, a failing ``g.csrf_token.check()`` aborts with HTTP 400.
    After the request, a token that was used is written back to the cookie.

    Configuration keys read by :meth:`init_app`:

    ``SECRET_KEY`` (required)
        Secret passed to every ``CsrfToken``.
    ``CSRF_COOKIE_NAME`` (default ``'csrf'``)
        Name of the cookie holding the token.
    ``CSRF_COOKIE_TIMEOUT`` (default ``timedelta(days=5)``)
        ``max_age`` of the cookie.
    ``CSRF_DISABLE`` (default: the value of ``TESTING``, else ``False``)
        Skip validation entirely when truthy.
    ``CSRF_COOKIE_SECURE`` (default ``False``)
        Set the cookie's ``Secure`` flag (sent over HTTPS only).
    ``CSRF_COOKIE_HTTPONLY`` (default ``False``)
        Set the cookie's ``HttpOnly`` flag (hidden from scripts).

    :param app: optional application to bind immediately via :meth:`init_app`.
    """

    # Defaults for the optional cookie flags; init_app overrides them from
    # the app config. False matches the original set_cookie behaviour.
    csrf_cookie_secure = False
    csrf_cookie_httponly = False

    def __init__(self, app=None):
        self.exempt_view_funcs = set()
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register the request hooks on *app* and read its configuration.

        :raises KeyError: if ``SECRET_KEY`` is missing from ``app.config``.
        """
        self.app = app
        app.before_request(self.on_before_request)
        app.after_request(self.on_after_request)
        config = app.config
        self.secret = config['SECRET_KEY']
        self.csrf_name = config.get('CSRF_COOKIE_NAME', 'csrf')
        self.csrf_timeout = config.get('CSRF_COOKIE_TIMEOUT', timedelta(days=5))
        self.csrf_disable = config.get('CSRF_DISABLE', config.get('TESTING', False))
        self.csrf_cookie_secure = config.get('CSRF_COOKIE_SECURE', False)
        self.csrf_cookie_httponly = config.get('CSRF_COOKIE_HTTPONLY', False)

    def exempt(self, f):
        """Decorator: skip CSRF validation for view function *f*; returns *f*."""
        self.exempt_view_funcs.add(f)
        return f

    def on_before_request(self):
        """Attach ``g.csrf_token`` and abort with 400 if validation fails."""
        # A missing (or empty) cookie makes CsrfToken generate a new token.
        g.csrf_token = CsrfToken(self.secret, request.cookies.get(self.csrf_name))
        if self.csrf_disable:
            return
        view_func = self.app.view_functions.get(request.endpoint)
        # No matching view function (e.g. unknown URL) or explicitly exempt.
        if not view_func or view_func in self.exempt_view_funcs:
            return
        if not g.csrf_token.check():
            abort(400)

    def on_after_request(self, response):
        """Write a used token to the CSRF cookie; always returns *response*."""
        # g.csrf_token can be absent if an earlier before-request handler
        # short-circuited the request before on_before_request ran.
        token = getattr(g, 'csrf_token', None)
        if token is None or not token.used:
            return response
        response.set_cookie(
            self.csrf_name,
            token.token,
            max_age=self.csrf_timeout,
            secure=self.csrf_cookie_secure,
            httponly=self.csrf_cookie_httponly,
        )
        # The token was handed out, so the response presumably embeds a
        # per-client value; tell caches to vary on the Cookie header.
        response.vary.add('Cookie')
        return response
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.