Skip to content

Instantly share code, notes, and snippets.

@rslinckx
Created October 14, 2010 21:15
Show Gist options
  • Save rslinckx/627062 to your computer and use it in GitHub Desktop.
Save rslinckx/627062 to your computer and use it in GitHub Desktop.
CsrfProtector for Flask applications
import logging
# Module-level logger; CSRF rejections in this module are reported through it.
log = logging.getLogger(__name__)
from datetime import timedelta
# NOTE(review): _endpoint_from_view_func is not referenced anywhere in the
# visible part of this file -- confirm it is unused before removing it.
from flask.helpers import _endpoint_from_view_func
from flask import request, g, abort
import os
import hmac
from hashlib import sha1
# Only CsrfProtector is exported; CsrfToken is deliberately left out of the
# public API.
__all__ = ['CsrfProtector']
class CsrfToken(object):
    """CSRF token for the current request (double-submit-cookie scheme).

    The token value is either the one supplied by the caller (normally read
    from the CSRF cookie) or a freshly generated random one. Calling the
    instance returns the token string and records that it was handed out, and
    :meth:`check` validates the current Flask request against it.
    """

    # HTTP methods that are subject to the CSRF check.
    # NOTE(review): PATCH is not covered here; kept as-is for backward
    # compatibility, but confirm that no state-changing PATCH views exist.
    _PROTECTED_METHODS = ("PUT", "POST", "DELETE")

    def __init__(self, secret, token=None):
        """Wrap an existing token or generate a new one.

        :param secret: key used when deriving a new token (bytes or text).
        :param token: existing token value; any falsy value (``None`` or an
            empty string) causes a new token to be generated.
        """
        self.secret = secret
        # Must be initialised before generate() can run: generate() flips it
        # to True, and a later assignment here would overwrite that.
        self.new = False
        self.token = token or self.generate()
        # Becomes True once the token has been handed out via __call__.
        self.used = False

    def __call__(self):
        """Return the token string and mark the token as used."""
        self.used = True
        return self.token

    def __repr__(self):
        return 'CsrfToken(%s)' % self.token

    @staticmethod
    def _to_bytes(value):
        """Return *value* as bytes, UTF-8 encoding it when it is text."""
        if isinstance(value, (bytes, bytearray)):
            return bytes(value)
        return value.encode('utf-8')

    def generate(self):
        """Create a new random token and flag this instance as ``new``.

        :returns: 40-character hex digest of an HMAC-SHA1 over 64 random bytes.
        """
        self.new = True
        # hmac.new() only accepts a bytes key on Python 3, while SECRET_KEY is
        # very often configured as a text string, so normalise it first.
        key = self._to_bytes(self.secret)
        return hmac.new(key, os.urandom(64), sha1).hexdigest()

    def _matches(self, candidate):
        """Compare *candidate* with the token in constant time.

        A missing candidate (``None``) never matches.
        """
        if candidate is None:
            return False
        # compare_digest() raises TypeError for non-ASCII text, so both sides
        # are compared as bytes.
        return hmac.compare_digest(
            self._to_bytes(candidate), self._to_bytes(self.token)
        )

    def check(self):
        """Validate the current request against this token.

        :returns: ``True`` when the request is allowed, ``False`` when it must
            be rejected. Requests whose method is not in
            ``_PROTECTED_METHODS`` and requests flagged as XMLHttpRequest are
            always allowed.
        """
        if request.method not in self._PROTECTED_METHODS:
            return True

        # Same test that ``request.is_xhr`` performed; that attribute was
        # removed from Werkzeug, so the header is inspected directly.
        requested_with = request.headers.get('X-Requested-With', '')
        if requested_with.lower() == 'xmlhttprequest':
            return True

        if request.is_secure:
            referrer = request.referrer
            if not referrer:
                log.info('Invalid CSRF: HTTPS with no referrer')
                return False
            # The referrer is a full URL ("https://host/path"), so it has to
            # be compared with the scheme-qualified host URL; comparing with
            # the bare host name can never match.
            expected_origin = request.host_url
            if not referrer.startswith(expected_origin):
                log.info('Invalid CSRF: HTTPS referrer does not match: %r != %r', referrer, expected_origin)
                return False

        form_csrf = request.form.get('csrf_token', None)
        if not self._matches(form_csrf):
            log.info('Invalid CSRF: Form token does not match cookie token: %r != %r', form_csrf, self.token)
            return False
        return True
class CsrfProtector(object):
    """Flask extension that runs a CSRF check before every request.

    A ``CsrfToken`` is stored on ``flask.g`` as ``g.csrf_token`` for each
    request. The token cookie is only written to the response when the token
    was actually used (called) while handling the request.
    """

    def __init__(self, app=None):
        """Create the protector, binding it to *app* right away if given."""
        # View functions registered through exempt() skip the check.
        self.exempt_view_funcs = set()
        if app is None:
            return
        self.init_app(app)

    def init_app(self, app):
        """Register the request hooks on *app* and read its configuration.

        Reads ``SECRET_KEY`` (required), ``CSRF_COOKIE_NAME`` (default
        ``'csrf'``), ``CSRF_COOKIE_TIMEOUT`` (default five days) and
        ``CSRF_DISABLE`` (defaults to the value of ``TESTING``).

        :raises KeyError: if ``SECRET_KEY`` is missing from the configuration.
        """
        self.app = app
        app.before_request(self.on_before_request)
        app.after_request(self.on_after_request)

        config = app.config
        self.secret = config['SECRET_KEY']
        self.csrf_name = config.get('CSRF_COOKIE_NAME', 'csrf')
        self.csrf_timeout = config.get('CSRF_COOKIE_TIMEOUT', timedelta(days=5))
        # Unless configured explicitly, the check is off whenever TESTING is on.
        testing_flag = config.get('TESTING', False)
        self.csrf_disable = config.get('CSRF_DISABLE', testing_flag)

    def exempt(self, f):
        """Mark view function *f* as exempt from the check; usable as a decorator."""
        self.exempt_view_funcs.add(f)
        return f

    def on_before_request(self):
        """Attach the request's token to ``g`` and abort with 400 on failure."""
        # A missing cookie yields None, which makes CsrfToken generate a
        # fresh token.
        cookie_value = request.cookies.get(self.csrf_name)
        g.csrf_token = CsrfToken(self.secret, cookie_value)

        view = self.app.view_functions.get(request.endpoint)
        unprotected = not view or view in self.exempt_view_funcs
        if unprotected or self.csrf_disable:
            return
        if g.csrf_token.check():
            return
        abort(400)

    def on_after_request(self, response):
        """Write the token cookie to *response* if the token was used."""
        token_was_used = hasattr(g, 'csrf_token') and g.csrf_token.used
        if token_was_used:
            response.set_cookie(
                self.csrf_name,
                g.csrf_token.token,
                max_age=self.csrf_timeout,
            )
            # The response now depends on the cookie, so tell caches.
            response.vary.add('Cookie')
        return response
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment