Created
April 29, 2013 16:30
-
-
Save djtriptych/5482779 to your computer and use it in GitHub Desktop.
Simple Redis-backed ORM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
db - Save and load data models in Redis. | |
~~ | |
""" | |
# TODO: Key object to manage separating class name from key name. | |
# FUTURE: Automatic and Manual indexes. | |
# FUTURE: Store Models in mongo, keys in redis | |
# Standard
import cPickle
import datetime
import hashlib
import itertools
import json
import os
import random
import time
import uuid
# 3rd party
import redis
from dateutil import parser
# The global connection will be loaded from a configuration file or set manually
# by the user.
CONN = None
SETTINGS_FILE = '~/.redisdb'


def configure(settings_file_path):
    """ Point the module-level CONN at the Redis server named in a JSON file.

    The file must hold a JSON object with `host`, `port` and `db` keys and may
    hold `password`. If the file cannot be opened (IOError), CONN is left
    unchanged and None is returned.

    Raises:
        ValueError: the file is not valid JSON, is not a JSON object, or lacks
            one of the required keys.
    """
    global CONN
    try:
        with open(os.path.expanduser(settings_file_path)) as settings_file:
            settings = json.load(settings_file)
    except IOError:
        # No settings file: keep whatever CONN currently is.
        return None
    try:
        host = settings['host']
        port = settings['port']
        db = settings['db']
        password = settings.get('password', None)
    except (AttributeError, KeyError, TypeError):
        # KeyError: a required key is missing; TypeError/AttributeError: the
        # JSON document is not an object.
        raise ValueError('Invalid settings')
    # Note that this will NOT fail even if there is no server at this
    # host/port until a command is sent using the bogus connection.
    CONN = redis.StrictRedis(host=host, port=port, db=db, password=password)


# Load default settings: a `.redisdb` file in the working directory takes
# precedence over SETTINGS_FILE. A missing file leaves CONN as None; a
# malformed one raises ValueError at import time.
if os.path.exists('.redisdb'):
    configure('.redisdb')
else:
    configure(SETTINGS_FILE)
class InvalidKeyError(Exception):
    """ Raised when a Key (or a Model key) is built from invalid arguments. """


class Key(object):
    """ Redis key naming a Model instance.

    A key has a *kind* (the lower-cased Model class name, or a lower-cased
    string) and a *name* (the remaining constructor arguments joined with
    SEP). `key()` returns the full Redis key, ``kind:name``.
    """
    # Used to separate key components
    SEP = ':'
    # Name component of the per-class set that holds every instance's key.
    ALL = '__all_set__'

    # String types accepted as a kind; `unicode` only exists on Python 2.
    try:
        _STRING_TYPES = (str, unicode)
    except NameError:
        _STRING_TYPES = (str, )

    def __init__(self, kind, *args):
        """ Build a key of `kind` whose name is made from `args`.

        # Good
        Key(Job)                        # random name
        Key(Job, 'test')
        Key(Job, j.key())               # reuse another Key's name
        Key(Job, event.id, account.id)  # components joined with SEP
        Key('job', 'test')              # kind given as a string
        # Bad
        Key(Job, j.key(), 'test')       # If Key is used, must be only argument.
        Key(object)                     # kind is neither Model class nor string

        Raises:
            InvalidKeyError: in the "Bad" cases above.
        """
        if not isinstance(kind, self._STRING_TYPES):
            try:
                is_model = issubclass(kind, Model)
            except TypeError:
                # issubclass() raises TypeError for non-class arguments.
                is_model = False
            if not is_model:
                raise InvalidKeyError('kind must be Model class or string')
        if len(args) > 1 and any(isinstance(arg, Key) for arg in args):
            raise InvalidKeyError('If Key is used as arg, must be only arg')
        if not args:
            # Random 32-hex-digit name. uuid4 draws from os.urandom; the old
            # md5(str(random.random())) had far less entropy than its length
            # suggests (Python 2's str() keeps only 12 digits of the float).
            args = (uuid.uuid4().hex, )
        elif isinstance(args[0], Key):
            # Support creating a Key given a Key: only its name is reused.
            args = (args[0].name(), )
        if isinstance(kind, self._STRING_TYPES):
            self.__kind = kind.lower()
        else:
            self.__kind = kind.__name__.lower()
        # join() stringifies every component.
        self.__name = self.join(*args)

    def name(self):
        """ Return the name part of the key (everything after the kind). """
        return self.__name

    def kind(self):
        """ Return the lower-cased kind. """
        return self.__kind

    def key(self):
        """ Return the full Redis key, ``kind:name``. """
        return self.join(self.kind(), self.name())

    @classmethod
    def from_raw(cls, value):
        """ Rebuild a Key from a full Redis key string (``kind:name``). The
        split is at the first SEP, so names may themselves contain SEP. """
        kind, _, name = value.partition(Key.SEP)
        return Key(kind, name)

    @classmethod
    def join(cls, *args):
        """ Join the str() of each argument with SEP. """
        return Key.SEP.join(map(str, args))

    @classmethod
    def AssociationSet(cls, one_obj, many_class):
        """ Return the key of the set holding the keys of `many_class`
        instances associated with the Model instance `one_obj`:
        ``assoc:<one kind>:<one name>:<many kind>``.

        Models a one-to-many relationship. `many_class` may be a Model class
        or a Model instance (its class is used).
        """
        # Convert Model instance to its class
        if isinstance(many_class, Model):
            many_class = many_class.__class__
        assert isinstance(one_obj, Model)
        return Key.join(
            'assoc',
            one_obj.__class__.__name__.lower(),
            one_obj.key().name(),
            many_class.__name__.lower())

    @staticmethod
    def ModelSet(cls_or_obj):
        """ Given a Model class or instance, return the key naming the set
        used to store keys of all instances of that model. """
        # Get the actual Model class if given an object
        if isinstance(cls_or_obj, Model):
            cls = cls_or_obj.__class__
        else:
            cls = cls_or_obj
        assert issubclass(cls, Model), "obj must be a Model class"
        assert cls is not Model, "Model base class cannot be keyed"
        return Key.join(cls.__name__.lower(), Key.ALL)
class Property(object):
    """ Descriptor storing one Model attribute in the instance's `_values`
    dict and converting it to/from the value kept in Redis.

    Args:
        default: value returned while the attribute is unset.
        choices: optional collection of allowed values, checked on assignment.
        required: flag read by Model.save when the serialized value is None.
    """
    def __init__(self, default=None, choices=None, required=False):
        self.default = default
        self.choices = choices
        self.required = required
        # Both filled in by _fix_up once the owning Model class exists.
        self.name = None
        self.cls = None

    def __get__(self, instance, owner):
        # Allow access to Property when __get__ called on class.
        if instance is None:
            return self
        # __get__ called on object; use Property descriptor
        return self._get_value(instance, owner)

    def __set__(self, instance, value):
        """ Validate/coerce `value`, enforce `choices`, then store it.

        Raises:
            TypeError: `value` is outside `choices` (subclass validators may
                raise TypeError too).
        """
        validated = self._validate(value)
        if validated is not None:
            # _validate may return a coerced value; None means "keep as is".
            value = validated
        if self.choices is not None and value not in self.choices:
            # Wrap in a tuple so tuple values format correctly.
            raise TypeError('%s not an allowed choice' % (value, ))
        self._set_value(instance, value)

    def __delete__(self, instance):
        """ Unset the value so reads fall back to `default`. Deleting a value
        that was never set is a no-op. """
        try:
            values = instance._values
        except AttributeError:
            return
        # pop() rather than del: an unset value would raise KeyError.
        values.pop(self.name, None)

    def _validate(self, value):
        """ Hook for subclasses: return a coerced value, or None to keep
        `value` unchanged; raise TypeError to reject it. """
        return None

    def _set_value(self, instance, value):
        instance._values[self.name] = value

    def _get_value(self, instance, owner=None):
        return instance._values.get(self.name, self.default)

    def _to_db_value(self, instance):
        """ Return serialized value. Useful for stuff like encoding a dict as a
        string. """
        return self._get_value(instance)

    def _from_db_value(self, instance, value):
        """ Create internal value from serialized value """
        self._set_value(instance, value)

    def _fix_up(self, cls, name):
        """ Record the owning class and the attribute name this property is
        bound to (the key used in `_values` and in the Redis hash). """
        self.cls = cls
        self.name = name
class KeyProperty(Property):
    """ Property referencing another Model instance; only the referenced
    object's full Redis key is stored.

    `self.model` (the referenced Model class) is not set here; subclasses
    such as ReferenceProperty must provide it for loading to work.
    """
    def _to_db_value(self, instance):
        """ Return the referenced object's full key, or None when unset. """
        obj = self._get_value(instance)
        if obj is None:
            # Unset reference: report None like other properties do instead
            # of failing with AttributeError on None.key().
            return None
        return obj.key().key()

    def _from_db_value(self, instance, value):
        """ Load the referenced object via its stored key (the result of
        self.model.load, which may be None if the object no longer exists). """
        key = Key.from_raw(value)
        obj = self.model.load(key)
        self._set_value(instance, obj)
class ReferenceProperty(KeyProperty):
    """ Many-to-one reference to an instance of `model`.

    On fix-up it also installs a reverse accessor named `collection_name` on
    `model` (a _ReverseReferenceProperty) listing the instances that
    reference it.

    Args:
        model: the referenced Model class.
        collection_name: attribute name for the reverse accessor on `model`.
    """
    def __init__(self, model, collection_name, *args, **kwargs):
        self.model = model
        self.collection_name = collection_name
        # Reverse accessor installed by _fix_up; None until then.
        self._reverse = None
        super(ReferenceProperty, self).__init__(*args, **kwargs)

    def _validate(self, value):
        if not isinstance(value, self.model):
            raise TypeError('Reference property must be Model instance')

    def _fix_up(self, cls, name):
        if self._reverse is not None:
            # Already wired up. MetaModel fixes up inherited properties again
            # for every subclass; keep the original owner and accessor rather
            # than mistaking our own accessor for a name clash.
            self.name = name
            return
        super(ReferenceProperty, self)._fix_up(cls, name)
        # Look in the class dicts directly: getattr() would run descriptors
        # (and a reverse accessor returns a generator, not itself).
        already_defined = any(self.collection_name in vars(klass)
                              for klass in self.model.__mro__)
        if already_defined:
            raise ValueError('attribute %s already exists' % self.collection_name)
        # Assign collection_name property to referenced Model
        self._reverse = _ReverseReferenceProperty(self.cls)
        setattr(self.model, self.collection_name, self._reverse)
class _ReverseReferenceProperty(Property):
    """ A collection of associated models, used to model entity-relationships.

    Installed on the referenced Model by ReferenceProperty:

        class Course(Model):
            name = StringProperty()

        class Student(Model):
            # Magically assigns Course.students (a _ReverseReferenceProperty)
            course = ReferenceProperty(Course, 'students')
            name = StringProperty()

        c = Course()
        c.name = 'SICP'
        s = Student()
        s.course = c
        s.name = 'Sue'
        s.course          # = <Course c>
        list(c.students)  # = [<Student s>] once s has been saved
    """
    def __init__(self, reference_model):
        self.reference_model = reference_model
        # Plain defaults; previously `self` was passed positionally and
        # became the descriptor's own default.
        super(_ReverseReferenceProperty, self).__init__()

    def __get__(self, model_instance, model_class):
        """ Yield the saved self.reference_model instances associated with
        model_instance, skipping keys whose data can no longer be loaded.

        NOTE: this is a generator function, so class-level access also
        returns an (unstarted) generator rather than the descriptor. As a
        result, lookups that filter getattr(cls, ...) results on
        isinstance(..., Property) do not see this accessor.
        """
        set_key = Key.AssociationSet(model_instance, self.reference_model)
        for raw_key in CONN.smembers(set_key):
            model = self.reference_model.load(Key.from_raw(raw_key))
            if model is not None:
                yield model

    def __set__(self, model_instance, value):
        # The descriptor protocol passes (instance, value); the old
        # one-argument signature turned assignment into a TypeError.
        raise ValueError('Cannot set reverse reference property.')

    @staticmethod
    def set_key(model_instance, prop):
        """ Return the association-set key listing objects that reference
        model_instance.

        `prop` may be a _ReverseReferenceProperty or the referenced Model
        class itself.
        """
        reference_model = getattr(prop, 'reference_model', prop)
        return Key.AssociationSet(model_instance, reference_model)
class DateTimeProperty(Property):
    """ Property holding a datetime, stored in Redis as a string of float
    seconds since 1970-01-01. Naive datetimes are treated as UTC.

    Args:
        auto_add_now: when true, an instance whose value is unset (or None)
            is stamped with datetime.datetime.utcnow() the first time the
            value is read (for example by save()). The stamp is per instance
            rather than computed once when the property is created.
    """
    def __init__(self, auto_add_now=False, *args, **kwargs):
        self.auto_add_now = auto_add_now
        super(DateTimeProperty, self).__init__(*args, **kwargs)

    def _get_value(self, instance, owner=None):
        if self.auto_add_now and instance._values.get(self.name) is None:
            # Stamp lazily and store the value so later reads and the
            # eventual save all agree.
            self._set_value(instance, datetime.datetime.utcnow())
        return super(DateTimeProperty, self)._get_value(instance, owner)

    def _to_db_value(self, instance):
        """ Return epoch seconds as a string, or None when unset. """
        date = self._get_value(instance)
        if date is None:
            return None
        # Shift aware datetimes to UTC and drop tzinfo; subtracting an aware
        # value from the naive epoch below would raise TypeError.
        offset = date.utcoffset()
        date = date.replace(tzinfo=None)
        if offset is not None:
            date -= offset
        seconds = (date - datetime.datetime(1970, 1, 1)).total_seconds()
        # repr() keeps full precision; Python 2's str() rounds floats to 12
        # significant digits, losing sub-second detail.
        return repr(float(seconds))

    def _from_db_value(self, instance, value):
        """ Parse epoch seconds into a naive UTC datetime; non-numeric values
        are parsed as date strings by dateutil. """
        try:
            seconds = float(value)
            date = datetime.datetime.utcfromtimestamp(seconds)
        except ValueError:
            date = parser.parse(value)
        self._set_value(instance, date)

    def _validate(self, value):
        if not isinstance(value, datetime.datetime):
            raise TypeError('Bad Date Value: %s ' % (value, ))
class PickleProperty(Property):
    """ Property holding any picklable Python value (None included).

    SECURITY NOTE(review): values are unpickled straight from Redis, so
    anyone who can write to the database can run code in this process.
    Only use against a trusted Redis instance.
    """
    def _to_db_value(self, instance):
        """ Pickle the current value into a string for Redis. """
        current = self._get_value(instance)
        pickled = cPickle.dumps(current)
        return pickled

    def _from_db_value(self, instance, value):
        """ Unpickle the stored string and keep the resulting object. """
        restored = cPickle.loads(value)
        self._set_value(instance, restored)
class JSONProperty(Property):
    """ Property holding a JSON-serializable value (dict, list, string,
    number, bool or None), stored in Redis as a JSON document. Because the
    encoded form is always a string, None is stored as JSON `null`. """
    def _to_db_value(self, instance):
        """ Encode the current value as a JSON string. """
        current = self._get_value(instance)
        encoded = json.dumps(current)
        return encoded

    def _from_db_value(self, instance, value):
        """ Decode the JSON string read from Redis and keep the result. """
        decoded = json.loads(value)
        self._set_value(instance, decoded)
class StringProperty(Property):
    """ Property holding text as a native `str` (UTF-8 encoded bytes on
    Python 2). Other values are converted with str(); assigning None leaves
    the value None instead of storing the string 'None'. """
    def _validate(self, v):
        """ Return `v` coerced to a native str (or None for None).

        Raises:
            TypeError: `v` is bytes that are not valid UTF-8, text that
                cannot be UTF-8 encoded, or a value str() cannot convert.
        """
        if v is None:
            return None
        text_type = type(u'')
        if isinstance(v, bytes):
            # Previously `str.encode('utf-8')` validated the literal
            # 'utf-8' rather than v.
            try:
                decoded = v.decode('utf-8')
            except UnicodeDecodeError:
                raise TypeError('UTF-8 strings only')
            # Python 2: bytes is str, keep it; Python 3: return the text.
            return v if bytes is str else decoded
        if isinstance(v, text_type):
            try:
                encoded = v.encode('utf-8')
            except UnicodeError:
                raise TypeError('UTF-8 strings only')
            # On Python 2, str(u'...') would raise for non-ASCII text.
            return encoded if bytes is str else v
        try:
            return str(v)
        except (TypeError, UnicodeError):
            raise TypeError('Invalid string')
class PasswordProperty(StringProperty):
    """ String property intended for passwords.

    NOTE(review): it adds nothing to StringProperty -- no hashing happens
    here, so the value is stored as given. Hash before assigning if plain
    text storage is not intended.
    """
class EmailProperty(StringProperty):
    """ Email address property: validated and coerced like StringProperty,
    then required to contain an '@'. """
    def _validate(self, v):
        """ Return the coerced address; raise TypeError if it has no '@'. """
        # Previously the StringProperty checks were skipped entirely.
        value = super(EmailProperty, self)._validate(v)
        if value is None:
            # Base validation may pass None through; nothing to check.
            return None
        if '@' not in value:
            raise TypeError('Invalid email address')
        return value
class FloatProperty(Property):
    """ Property holding a float; assigned values are coerced with float(). """
    def _from_db_value(self, instance, value):
        """ Parse float value from the string stored in Redis. """
        self._set_value(instance, float(value))

    def _validate(self, value):
        """ Return float(value).

        Raises:
            TypeError: the value cannot be converted to a float.
        """
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            # Narrowed from a bare except, which also swallowed
            # KeyboardInterrupt and SystemExit.
            raise TypeError('Invalid float')
class IntegerProperty(Property):
    """ Property holding an int; assigned values are coerced with int()
    (so floats are truncated). """
    def _from_db_value(self, instance, value):
        """ Parse int value from the string stored in Redis. """
        self._set_value(instance, int(value))

    def _validate(self, value):
        """ Return int(value).

        Raises:
            TypeError: the value cannot be converted to an int (including
                inf/nan floats).
        """
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # Narrowed from a bare except, which also swallowed
            # KeyboardInterrupt and SystemExit.
            raise TypeError('Invalid int')
class BooleanProperty(Property):
    """ Property holding True or False, stored in Redis as 1 or 0. Only
    values equal to True or False may be assigned. """
    def __init__(self, choices=None, *args, **kwargs):
        if choices is not None:
            raise ValueError('BooleanProperty does not support `choices` '
                             'argument')
        super(BooleanProperty, self).__init__(*args, **kwargs)
        # Set after the base initializer, which would otherwise reset
        # choices to None and silently drop the restriction.
        self.choices = (True, False)

    def _to_db_value(self, instance):
        """ Convert boolean to 1 or 0 (an unset value becomes 0). """
        data = self._get_value(instance)
        return 1 if data else 0

    def _from_db_value(self, instance, value):
        """ Load boolean from the stored 1 or 0. """
        # Redis returns strings, and bool('0') is True; go through int().
        self._set_value(instance, bool(int(value)))
class MetaModel(type):
    """ Metaclass for Model: right after a class object is built, let it
    bind its Property descriptors to their attribute names. """
    def __init__(cls, name, bases, classdict):
        """ Create a Model class, then fix up its properties. """
        parent = super(MetaModel, cls)
        parent.__init__(name, bases, classdict)
        # Properties can only learn their names once the class body has run.
        cls._fix_up_properties()
class Model(object):
    """ Base class for objects persisted as Redis hashes.

    An instance is stored under its Key (``kind:name``) as a hash mapping
    property names to serialized values. Its key is also added to a
    per-class set (Key.ModelSet) so that all() can enumerate instances.
    """
    __metaclass__ = MetaModel

    # Both derive from Exception: raising a class that does not is deprecated
    # on Python 2 and a TypeError on Python 3.
    class MissingKeyError(Exception):
        """ Not raised anywhere in this file. """

    class MissingRequiredValueError(Exception):
        """ Raised by save() when a required property serializes to None;
        the argument is the property name. """

    def __init__(self, key=None):
        """ Create an unsaved instance.

        Args:
            key: None (random name), a string name, or a tuple of arguments
                passed to Key(cls, *key) -- e.g. name components or a 1-tuple
                holding a Key.

        Raises:
            InvalidKeyError: `key` is of any other type.
        """
        key = () if key is None else key
        if not isinstance(key, (str, unicode, tuple)):
            raise InvalidKeyError('key must be a string or tuple')
        if isinstance(key, (str, unicode)):
            key = (key, )
        self.__key = Key(self.__class__, *key)
        self._values = {}

    def key(self):
        return self.__key

    @classmethod
    def _fix_up_properties(cls):
        # Tell all Property objects their assigned name.
        for name in set(dir(cls)):
            attr = getattr(cls, name, None)
            if isinstance(attr, Property):
                attr._fix_up(cls, name)

    @classmethod
    def all(cls):
        """ Generate all saved Models of a given kind.

        Keys still listed in the class's ModelSet whose hash is gone are
        skipped instead of being yielded as None.
        """
        for raw_key in CONN.smembers(Key.ModelSet(cls)):
            instance = cls.load(Key.from_raw(raw_key))
            if instance is not None:
                yield instance

    @classmethod
    def load(cls, *args):
        """ Load a Model by key_name, key object, or argument list used to
        construct a key. Return None if no hash exists under that key. """
        key = Key(cls, *args)
        data = CONN.hgetall(key.key())
        if not data:  # Empty dict when key is missing
            return None
        # Create bare model instance with the same key.
        instance = cls(key=args)
        # Assign property values from DB; absent fields get the default.
        for name, prop in instance._properties().items():
            value = data.get(name)
            if value is not None:
                prop._from_db_value(instance, value)
            else:
                prop._set_value(instance, prop.default)
        return instance

    def save(self):
        """ Write this model to Redis using one redis-py pipeline.

        Properties that serialize to None are removed from the hash, so
        loading yields their default. The key is added to the class's
        ModelSet and to the association set of every object referenced
        through a ReferenceProperty that is set.

        NOTE(review): if a reference is changed from one object to another,
        the old object's association set still lists this key.

        Raises:
            MissingRequiredValueError: a required property serializes to None.
        """
        properties = self._properties()
        # Gather data from properties.
        data = {}
        for name, prop in properties.items():
            value = prop._to_db_value(self)
            if value is None and prop.required:
                raise self.MissingRequiredValueError(name)
            data[name] = value
        present = {name: value for name, value in data.items() if value is not None}
        missing = [name for name, value in data.items() if value is None]
        _k = self.key().key()
        # Commit data in transaction.
        with CONN.pipeline() as pipe:
            # None values are never sent to HMSET (newer redis-py rejects
            # them, older ones store the string 'None').
            if present:
                pipe.hmset(_k, present)
            if missing:
                pipe.hdel(_k, *missing)
            pipe.sadd(Key.ModelSet(self), _k)
            # Reference properties are also recorded in an association set
            # of the referenced object; unset references are skipped.
            for prop in properties.values():
                if isinstance(prop, ReferenceProperty):
                    target = prop._get_value(self)
                    if target is not None:
                        pipe.sadd(Key.AssociationSet(target, self), _k)
            pipe.execute()

    def delete(self):
        """ Delete this model. Return the number of keys removed by DEL
        (true if the model was actually deleted, else false).

        Also removes the key from the class's ModelSet and from the
        association sets of objects it references, and deletes the
        association sets listing objects that reference this one.
        """
        _k = self.key().key()
        with CONN.pipeline() as pipe:
            # Remove model
            pipe.delete(_k)
            # Remove model from ModelSet
            pipe.srem(Key.ModelSet(self), _k)
            for prop in self._properties().values():
                if isinstance(prop, ReferenceProperty):
                    target = prop._get_value(self)
                    if target is not None:
                        pipe.srem(Key.AssociationSet(target, self), _k)
            # Reverse accessors are found in the class dicts: read through
            # getattr() (as _properties does) they yield a generator rather
            # than the descriptor.
            for klass in type(self).__mro__:
                for attr in vars(klass).values():
                    if isinstance(attr, _ReverseReferenceProperty):
                        pipe.delete(Key.AssociationSet(self, attr.reference_model))
            results = pipe.execute()
        return results[0]

    def _properties(self):
        """ Map attribute name -> Property descriptor for this Model class. """
        cls = self.__class__
        found = {}
        for attr in dir(cls):
            value = getattr(cls, attr)
            if isinstance(value, Property):
                found[attr] = value
        return found
def test_key():
    """ Test key creation and access, plus a save/load/delete round trip.

    Requires a Redis server on localhost:7777. Flushes db 15 and, as a side
    effect, sets the module-level CONN to that server.
    """
    global CONN
    CONN = redis.StrictRedis(
        host='localhost',
        port=7777,
        db=15,
        password=None,
    )
    import pytest
    CONN.flushdb()

    class Radiohead(Model):
        name = StringProperty()

    key = Key(Radiohead, 'Thom')
    assert key.key() == 'radiohead:Thom'
    assert key.name() == 'Thom'
    assert key.kind() == 'radiohead'
    r = Radiohead(key='Colin')
    assert r.key().name() == 'Colin'
    # A string kind is valid (Key.from_raw relies on it), so Key('Thom') is
    # not an error. One statement per raises block, so a non-raising line
    # cannot hide behind a raising one.
    with pytest.raises(InvalidKeyError):
        Key(Radiohead, key, 'Thom')
    with pytest.raises(InvalidKeyError):
        Key(object)
    c = Radiohead(key=('Thom',))
    c.name = 'Thom'
    c.save()
    thom = Radiohead.load('Thom')
    assert thom is not None
    assert thom.name == 'Thom'
    thom.delete()  # There's... such a chill...
    assert Radiohead.load('Thom') is None
def test_ref():
    """ Test reference properties / association sets.

    Uses the module-level CONN (e.g. as set by test_key).
    """
    class Course(Model):
        name = StringProperty()
        schedule = DateTimeProperty()

    class Student(Model):
        name = StringProperty()
        course = ReferenceProperty(Course, 'students')

    try:
        s = Student()
        s.name = 'Sue'
        t = Student()
        t.name = 'Tony'
        c = Course()
        c.name = 'S.I.C.P.'
        c.schedule = datetime.datetime.now()
        s.course = c
        t.course = c
        s.save()
        t.save()
        c.save()
        assert Key.AssociationSet(c, Student) == 'assoc:course:{}:student'.format(c.key().name())
        students = list(c.students)
        assert len(students) == 2
        assert sorted(student.name for student in students) == ['Sue', 'Tony']
        d = Course.load(c.key())
        assert d is not None
        assert d.name == 'S.I.C.P.'
    finally:
        # Delete everything; skip None in case all() yields entries whose
        # hash is already gone.
        for o in list(Student.all()) + list(Course.all()):
            if o is not None:
                o.delete()
def test_props():
    """ Round-trip every Property type through save() and load().

    Uses the module-level CONN (e.g. as set by test_key) and deletes the
    saved model afterwards.
    """
    class Bag(Model):
        _int = IntegerProperty()
        _bool = BooleanProperty()
        _str = StringProperty()
        _float = FloatProperty()
        _date = DateTimeProperty(auto_add_now=True)
        _json = JSONProperty()
        _none = FloatProperty()
        _pickle = PickleProperty()

    # Bound before the try so the cleanup never hits a NameError.
    b = None
    try:
        b = Bag()
        b._int = 42
        b._bool = True
        b._str = 'Unladen swallow'
        b._float = 0.1
        b._json = [1, 2]
        b._none = 1
        del b._none
        b._pickle = Model()
        b.save()
        c = Bag.load(b.key())
        assert c is not None
        assert c._int == 42
        assert c._str == 'Unladen swallow'
        assert c._float == 0.1
        assert c._bool is True
        assert c._json == [1, 2]
        assert type(c._pickle) is Model
        assert c._none is None
        assert isinstance(c._date, datetime.datetime)
        # Saving a loaded model again must round-trip too.
        k = c.key()
        c.save()
        again = Bag.load(k)
        assert again is not None
        assert again._int == 42
    finally:
        # b, c and again share one key, so a single delete cleans up.
        if b is not None:
            b.delete()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment