Created
March 9, 2011 19:41
-
-
Save slacy/862824 to your computer and use it in GitHub Desktop.
Pyramid session object implemented on top of MongoDB using minimongo.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import time | |
import random | |
import logging | |
from minimongo.model import Model, MongoCollection | |
from pyramid.interfaces import ISession | |
from zope.interface import implements | |
from pyramid.response import Response | |
class SessionGuts(object):
    """Per-request bookkeeping that rides along with a session in memory.

    The session document itself is stored in MongoDB; this companion object
    never is. It records the request the session belongs to, whether the
    session has been accessed and therefore needs saving (``dirty``), and
    whether the session key and/or CSRF token should be rotated when the
    request ends.
    """

    def __init__(self, request, dirty, rotate_session_key, rotate_csrf_token):
        # Pending end-of-request rotations.
        self.rotate_session_key = rotate_session_key
        self.rotate_csrf_token = rotate_csrf_token
        # Set once the session is touched during this request.
        self.dirty = dirty
        # The request whose response will carry the session cookie.
        self.request = request
def get_session_key():
    """Return a new, unguessable session key as a lowercase hex string.

    The key is a zero-padded hex timestamp (whole seconds since the epoch,
    8 digits until the year 2106) followed by 16 hex digits of randomness.
    Putting the timestamp first keeps keys in roughly ascending order, so
    MongoDB can index the session keys more efficiently.
    """
    # int(): "%x" requires an integer on Python 3, and time.time() is a float.
    timestamp = '%08x' % int(time.time())
    # The session key is a bearer credential, so the random part must come
    # from the OS CSPRNG rather than the predictable Mersenne Twister that
    # backs the module-level random.* functions.
    randomness = '%016x' % random.SystemRandom().getrandbits(64)
    return timestamp + randomness
def wrap_access(wrapped):
    """Wrap a session method so that every call counts as an access.

    Each call stamps ``session.accessed`` with the current time. If the
    session has its in-memory ``_guts`` bookkeeping, the call also marks it
    dirty. The first such call registers ``session._set_session_cookie`` as
    a response callback on ``_guts.request``, so the session is saved and
    its cookie set when the response is produced.

    :param wrapped: the unwrapped method, called as ``wrapped(session, ...)``.
    :returns: the wrapping function, carrying ``wrapped``'s docstring and name.
    """
    def accessed(session, *args, **kwargs):
        session.accessed = datetime.datetime.now()
        # Wrapped methods also run while the session is being loaded from or
        # saved to the DB, when it has no _guts; skip the dirty tracking then.
        if hasattr(session, '_guts'):
            guts = session._guts
            # The response callback clears guts.request once it has run.
            # Without this guard, any later access would dereference None.
            if not guts.dirty and guts.request is not None:
                guts.request.add_response_callback(
                    session._set_session_cookie)
            guts.dirty = True
        return wrapped(session, *args, **kwargs)
    accessed.__doc__ = wrapped.__doc__
    # Keep the original name for tracebacks/introspection. functools.wraps is
    # avoided because on Python 2 it fails on builtin descriptors lacking
    # __module__.
    accessed.__name__ = getattr(wrapped, '__name__', accessed.__name__)
    return accessed
def MinimongoSessionFactoryConfig(
    timeout=30 * 86400,  # 30-day sessions
    cookie_name='session',
    cookie_max_age=None,
    cookie_path='/',
    cookie_domain=None,
    cookie_secure=False,
    cookie_httponly=False,
    cookie_on_exception=False,
):
    """Build a session class whose data lives in MongoDB (``web.session``).

    Returns ``MinimongoSession``, a minimongo ``Model`` declared to implement
    Pyramid's ``ISession``. Instantiating it with a request loads the session
    named by the request's session cookie, or creates a new one.

    :param timeout: seconds of inactivity after which a stored session is
        discarded and replaced (default 30 days).
    :param cookie_name: name of the cookie carrying the session key.
    :param cookie_max_age: passed to ``response.set_cookie`` as ``max_age``.
    :param cookie_path: passed to ``response.set_cookie`` as ``path``.
    :param cookie_domain: passed to ``response.set_cookie`` as ``domain``.
    :param cookie_secure: passed to ``response.set_cookie`` as ``secure``.
    :param cookie_httponly: passed to ``response.set_cookie`` as ``httponly``.
    :param cookie_on_exception: stored as ``_cookie_on_exception``; not
        otherwise read in this function.
    :returns: the ``MinimongoSession`` class.
    """
    # Session ids and CSRF tokens are secrets. Draw them from the OS CSPRNG,
    # not the predictable Mersenne Twister behind the module-level random.*.
    sysrandom = random.SystemRandom()

    def _make_token():
        """Return a fresh 128-bit random token as 32 lowercase hex digits."""
        return '%032x' % sysrandom.getrandbits(128)

    class MinimongoSession(Model):
        """A session stored as one MongoDB document keyed by ``_id``.

        Every field is persisted except ``_guts`` (a SessionGuts holding
        per-request bookkeeping). ``_guts`` is removed before each save and
        restored afterwards.
        """
        mongo = MongoCollection(database='web', collection='session')
        implements(ISession)

        # configuration parameters
        _cookie_name = cookie_name
        _cookie_max_age = cookie_max_age
        _cookie_path = cookie_path
        _cookie_domain = cookie_domain
        _cookie_secure = cookie_secure
        _cookie_httponly = cookie_httponly
        _cookie_on_exception = cookie_on_exception
        _timeout = timeout

        def __init__(self, request=None):
            """Load the session named by *request*'s cookie, or start one.

            Without a request, the instance is left as a bare document with
            no _guts. Presumably that is how instances are built while
            loading from or saving to the DB (confirm).
            """
            if not request:
                return
            self._guts = SessionGuts(request=request,
                                     dirty=False,
                                     rotate_session_key=False,
                                     rotate_csrf_token=False)
            now = datetime.datetime.now()
            stored_accessed = None
            cookieval = request.cookies.get(self._cookie_name)
            if cookieval is not None:
                session_object = MinimongoSession.collection.find_one(
                    {'_id': cookieval})
                if session_object:
                    # wrap_access assigns session.accessed = now() on every
                    # wrapped call (update, __contains__, __getitem__, ...).
                    # That overwrites the stored 'accessed' before it could be
                    # compared. So read the stored last-access time from the
                    # raw document first, via dict.get to bypass any wrapping.
                    # NOTE(review): assumes find_one returns the stored value
                    # unmodified; confirm for minimongo's document class.
                    stored_accessed = dict.get(session_object, 'accessed')
                    # Mirror the stored fields into self; this is what gets
                    # saved back to the DB when things change.
                    self.update(session_object)
            # Nothing loaded from the DB: initialize a session from scratch.
            if '_id' not in self:
                self['_id'] = get_session_key()
                self['created'] = now
                self['session_id'] = _make_token()
            if (isinstance(stored_accessed, datetime.datetime) and
                    now - stored_accessed > datetime.timedelta(
                        seconds=self._timeout)):
                # Idle too long: remove the old session, start a new one.
                self.invalidate()
            # Item assignment is a wrapped accessor, so this marks the
            # session dirty (_guts.dirty) and it is written out at the end of
            # the request.
            self['accessed'] = now

        # ISession methods
        def changed(self):
            """No-op required by ISession.

            Wrapped accessors already mark the session dirty, and dirty
            sessions are saved by the response callback, so there is nothing
            extra to record here.
            """
            pass

        def invalidate(self):
            """Discard this session and replace it with a fresh, empty one.

            Removes the stored document and clears every field. It then
            assigns a new key, creation time and session id, keeping the
            per-request _guts bookkeeping.
            """
            now = datetime.datetime.now()
            old_guts = self._guts
            self.remove()
            # clear() also drops _guts, which lives in the dict (hence the
            # rip-out-before-save in _set_session_cookie), so restore it.
            self.clear()
            self._guts = old_guts
            # The old cookie is not unset explicitly. _set_session_cookie
            # re-sets it to the new key; clear() is a wrapped accessor, which
            # should have ensured that callback is registered.
            self['_id'] = get_session_key()
            # A fresh, clean session:
            self['created'] = now
            self['session_id'] = _make_token()

        # non-modifying dictionary methods
        get = wrap_access(Model.get)
        __getitem__ = wrap_access(Model.__getitem__)
        items = wrap_access(Model.items)
        iteritems = wrap_access(Model.iteritems)
        values = wrap_access(Model.values)
        itervalues = wrap_access(Model.itervalues)
        keys = wrap_access(Model.keys)
        iterkeys = wrap_access(Model.iterkeys)
        __contains__ = wrap_access(Model.__contains__)
        has_key = wrap_access(Model.has_key)
        __len__ = wrap_access(Model.__len__)
        __iter__ = wrap_access(Model.__iter__)

        # modifying dictionary methods
        clear = wrap_access(Model.clear)
        update = wrap_access(Model.update)
        setdefault = wrap_access(Model.setdefault)
        pop = wrap_access(Model.pop)
        popitem = wrap_access(Model.popitem)
        __setitem__ = wrap_access(Model.__setitem__)
        __delitem__ = wrap_access(Model.__delitem__)

        def _set_session_cookie(self, request, response):
            """Response callback: save the session if dirty and set its cookie.

            Runs at the end of the request. It applies any requested
            session-key or CSRF-token rotation, writes the document to
            MongoDB when dirty, and sets the session cookie on *response*.
            Always returns True.
            """
            # The request is finishing: break the request <-> session cycle.
            self._guts.request = None
            # Non-Response objects (possibly mocks in unit tests) get no
            # cookie this time.
            if not isinstance(response, Response):
                return True
            if self._guts.rotate_session_key:
                old_id = self['_id']
                self['_id'] = get_session_key()
                # Delete the document stored under the old key. The dirty
                # save below writes the session under the new key. Replacing
                # the old document in place cannot work: MongoDB forbids
                # changing _id, and self still contains the unsaveable _guts.
                self.collection.remove({'_id': old_id})
                self._guts.dirty = True
            if self._guts.rotate_csrf_token:
                self.new_csrf_token()
                self._guts.dirty = True
            # Read the key while still dirty. Once dirty is reset below, a
            # wrapped access would try to register a callback on the
            # now-cleared request.
            session_key = self['_id']
            if self._guts.dirty:
                # Rip out the guts, save the object, then put the guts back.
                guts = self._guts
                del self._guts
                self.save()
                self._guts = guts
                self._guts.dirty = False
            response.set_cookie(
                self._cookie_name,
                value=session_key,
                max_age=self._cookie_max_age,
                path=self._cookie_path,
                domain=self._cookie_domain,
                secure=self._cookie_secure,
                httponly=self._cookie_httponly,
            )
            return True

        def rotate_session_key(self):
            """Request a new session key, applied at the end of this request."""
            self._guts.rotate_session_key = True

        def rotate_csrf_token(self):
            """Request a new CSRF token, applied at the end of this request."""
            self._guts.rotate_csrf_token = True

        @wrap_access
        def new_csrf_token(self):
            """Generate, set and return a new CSRF token for this session."""
            token = _make_token()
            # This assignment sets the dirty bit automatically.
            self['csrft'] = token
            return token

        @wrap_access
        def get_csrf_token(self):
            """Get the current CSRF token value, generating one if necessary."""
            token = self.get('csrft', None)
            if token is None:
                token = self.new_csrf_token()
            return token

        @wrap_access
        def csrf_valid(self, csrf_token):
            """Return whether *csrf_token* matches this session's CSRF token.

            Once a token has been checked, rotation is requested for the end
            of this request, whatever the outcome. Tokens that are never
            checked are left alone.
            """
            expected = self.get_csrf_token()
            equal = (csrf_token == expected)
            if not equal:
                # Never log the session's real token: it is a secret, and
                # anyone with log access could use it to forge requests. %r
                # escapes control characters in the submitted value.
                logging.warning(
                    "CSRF INVALID: %r does not match the session token",
                    csrf_token)
            self.rotate_csrf_token()
            return equal

    return MinimongoSession
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment