Last active
September 26, 2017 19:39
-
-
Save satiani/b20d7e81e48a041d241d61e9aab61be2 to your computer and use it in GitHub Desktop.
Gist for SQLAlchemy question
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import Column, Integer, String, create_engine, func | |
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta | |
from sqlalchemy.orm.descriptor_props import CompositeProperty | |
from sqlalchemy.sql.functions import GenericFunction | |
from sqlalchemy.orm import composite, sessionmaker | |
from sqlalchemy.sql.expression import case | |
import sqlalchemy.types as types | |
class Money(object):
    """An amount paired with a currency code.

    Instances hold plain Python values, but ``MoneyComparator`` also builds
    ``Money`` objects whose ``amount``/``currency`` are SQLAlchemy column
    expressions, so the operators below only rely on ``==``, ``!=``, the
    arithmetic operators and truth-testing of the results.

    Arithmetic and ordering between two ``Money`` values require equal
    currencies; a mismatch raises ``ValueError`` (an explicit check, so it
    still runs under ``python -O``). An operand that is not ``Money`` makes
    the operator return ``NotImplemented``, letting Python raise
    ``TypeError`` (or, for ``==``/``!=``, fall back to identity), which keeps
    comparisons such as ``money == None`` from crashing.
    """

    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __repr__(self):
        return '%s(%r, %r)' % (type(self).__name__, self.amount, self.currency)

    def _compatible(self, other):
        """Return False for a non-Money operand; raise ValueError on a currency mismatch."""
        if not isinstance(other, Money):
            return False
        if other.currency != self.currency:
            raise ValueError(
                'currency mismatch: %r vs %r' % (self.currency, other.currency))
        return True

    def __eq__(self, other):
        # Different currencies are simply unequal, not an error.
        if not isinstance(other, Money):
            return NotImplemented
        return self.amount == other.amount and self.currency == other.currency

    def __ne__(self, other):
        if not isinstance(other, Money):
            return NotImplemented
        return self.amount != other.amount or self.currency != other.currency

    def __lt__(self, other):
        if not self._compatible(other):
            return NotImplemented
        return self.amount < other.amount

    def __le__(self, other):
        # Without this, Python 2 silently used default (arbitrary) ordering.
        if not self._compatible(other):
            return NotImplemented
        return self.amount <= other.amount

    def __gt__(self, other):
        if not self._compatible(other):
            return NotImplemented
        return self.amount > other.amount

    def __ge__(self, other):
        if not self._compatible(other):
            return NotImplemented
        return self.amount >= other.amount

    def __add__(self, other):
        if not self._compatible(other):
            return NotImplemented
        return Money(self.amount + other.amount, self.currency)

    def __sub__(self, other):
        if not self._compatible(other):
            return NotImplemented
        return Money(self.amount - other.amount, self.currency)

    def __mul__(self, other):
        # Money * Money multiplies the amounts (original behavior, kept as is).
        if not self._compatible(other):
            return NotImplemented
        return Money(self.amount * other.amount, self.currency)

    def __div__(self, other):
        if not self._compatible(other):
            return NotImplemented
        return Money(self.amount / other.amount, self.currency)

    # Python 3 (and Python 2 with `from __future__ import division`) dispatch
    # `/` to __truediv__ instead of __div__.
    __truediv__ = __div__

    def __neg__(self):
        return Money(-self.amount, self.currency)

    def __composite_values__(self):
        """Return ``(amount, currency)`` for SQLAlchemy's composite mapping."""
        return self.amount, self.currency
class MoneyComparator(CompositeProperty.Comparator):
    """SQL-expression operators for ``MoneyComposite`` attributes.

    Results are ``Money`` objects whose ``amount`` and ``currency`` are
    SQLAlchemy column expressions rather than plain values.
    """

    @staticmethod
    def _money_clauses(operand):
        """Return ``(amount, currency)`` expressions for *operand*.

        *operand* may be a ``Money`` value or anything exposing
        ``__clause_element__()`` with ``.clauses`` (this comparator, or
        another composite attribute such as ``BankAccount.balance_odd``).
        """
        if isinstance(operand, Money):
            return operand.amount, operand.currency
        clauses = operand.__clause_element__().clauses
        return clauses[0], clauses[1]

    def __add__(self, other):
        amount, currency = self._money_clauses(self)
        # Previously this read self's clauses twice, so `a + b` built `a + a`.
        other_amount, _ = self._money_clauses(other)
        # Whether the two currency columns agree is only known per row at
        # query time, so it cannot be checked while building the expression;
        # the result keeps this attribute's currency column.
        return Money(amount + other_amount, currency)

    def __neg__(self):
        amount, currency = self._money_clauses(self)
        return Money(-amount, currency)
class MoneyComposite(CompositeProperty):
    """Composite property storing a ``Money`` as an amount and a currency column.

    The two columns are created without names; ``undefer_column_names``
    assigns them once the owning attribute name is known.
    """

    def __init__(self):
        amount_column = Column(Integer)
        currency_column = Column(String(length=3))
        super(MoneyComposite, self).__init__(
            Money,
            amount_column,
            currency_column,
            comparator_factory=MoneyComparator,
        )

    def get_column_names(self, key):
        """Return ``[amount_name, currency_name]`` derived from attribute *key*."""
        return [template.format(key) for template in ('{}_amount', '{}_currency')]

    def undefer_column_names(self, key):
        """Name this composite's columns after the attribute *key*.

        Simulates a "declarative" way of declaring money composite columns
        (see how they are defined in ``BankAccount``). ``CustomDeclarativeMeta``
        calls this to initialize the composite columns on the table: for
        ``key='balance_even'`` the columns become ``balance_even_amount`` and
        ``balance_even_currency``. Both ``name`` and ``key`` are set.
        """
        pairs = zip(self.columns, self.get_column_names(key))
        for column, new_name in pairs:
            column.name = new_name
            column.key = new_name
class CustomDeclarativeMeta(DeclarativeMeta):
    """Declarative metaclass that names ``MoneyComposite`` columns.

    Before delegating to ``DeclarativeMeta``, every ``MoneyComposite`` found
    directly in the class namespace *d* gets its columns named after the
    attribute holding it. Composites inherited from base classes are not in
    *d* and are therefore not visited.
    """

    def __new__(cls, name, bases, d):
        # items() instead of the Python 2-only iteritems(): works on 2 and 3.
        for attr_name, value in d.items():
            if isinstance(value, MoneyComposite):
                value.undefer_column_names(attr_name)
        return DeclarativeMeta.__new__(cls, name, bases, d)
Base = declarative_base(metaclass=CustomDeclarativeMeta)


class BankAccount(Base):
    """Demo model with a category and two ``Money`` balances."""

    __tablename__ = 'bank_account'

    id = Column(Integer, primary_key=True)
    category = Column(String(100))
    # CustomDeclarativeMeta.__new__ names each composite's two columns after
    # its attribute: balance_even -> balance_even_amount and
    # balance_even_currency; balance_odd -> balance_odd_amount and
    # balance_odd_currency.
    balance_even = MoneyComposite()
    balance_odd = MoneyComposite()
def generate_money_type_engine(currency):
    """Build a ``TypeEngine`` subclass whose results are wrapped as ``Money``.

    :param currency: currency code attached to every value read back.
    :returns: the ``MoneyTypeEngine`` *class* (not an instance). Its result
        processor maps each non-``None`` raw value to
        ``Money(value, currency)`` and passes ``None`` through unchanged.
    """
    def wrap_as_money(raw_value):
        if raw_value is not None:
            return Money(raw_value, currency)
        return raw_value

    class MoneyTypeEngine(types.TypeEngine):
        def result_processor(self, dialect, coltype):
            return wrap_as_money

    return MoneyTypeEngine
class MoneySumFunction(GenericFunction):
    """``func.sum_money(<MoneyComposite attribute>, currency=...)``.

    Emitted with SQL function name ``sum`` over the composite's amount
    column; the result type wraps the value as ``Money`` in *currency*.
    """

    name = 'sum'
    identifier = 'sum_money'

    def __init__(self, *args, **kwargs):
        """
        :param args: the first positional argument must be a mapped
            ``MoneyComposite`` attribute (e.g. ``BankAccount.balance_even``);
            any further positional arguments are ignored.
        :param currency: required keyword; the currency of the result.
        :raises TypeError: if no attribute is given, it is not a
            ``MoneyComposite`` attribute, or ``currency`` is missing/None.
        """
        # Explicit checks instead of assert: asserts vanish under `python -O`,
        # which would let a non-composite column be summed silently.
        if not args:
            raise TypeError('sum_money() requires a MoneyComposite attribute')
        sa_property_proxy = args[0]
        composite_property = getattr(sa_property_proxy, 'property', None)
        if not isinstance(composite_property, MoneyComposite):
            raise TypeError(
                "Only MoneyComposite columns may be used in this function")
        currency = kwargs.pop('currency', None)
        if currency is None:
            raise TypeError("sum_money() requires a 'currency' keyword argument")
        kwargs['type_'] = generate_money_type_engine(currency)
        amount_column = composite_property.columns[0]
        # get column through declarative class so that it is fully
        # instrumented by sqlalchemy
        amount_attribute = getattr(sa_property_proxy.class_, amount_column.name)
        GenericFunction.__init__(self, amount_attribute, **kwargs)
engine = create_engine('sqlite:///')
Base.metadata.create_all(engine)
Base.metadata.bind = engine
DBSession = sessionmaker(bind=engine)
session = DBSession()

# Ten accounts alternating 'even'/'odd' categories, each holding 3 USD in
# both balances.
for i in range(10):
    new_account = BankAccount(
        category='even' if i % 2 == 0 else 'odd',
        balance_even=Money(3, 'USD'),
        balance_odd=Money(3, 'USD'),
    )
    session.add(new_account)
session.commit()

# Works: sums balance_even's amount over all 10 rows -> should print 30.
total_balances = session.query(
    func.sum_money(BankAccount.balance_even, currency='USD')).scalar()
# print(...) with a single argument behaves the same on Python 2 and 3.
print(total_balances.amount)

# Fails: adding two composite attributes yields a Money built from SQL
# expressions, not a mapped MoneyComposite attribute, and sum_money only
# accepts the latter.
total_balances = session.query(func.sum_money(
    BankAccount.balance_even + BankAccount.balance_odd,
    currency='USD')).scalar()

# Also fails (reached only if the query above is removed): whatever case()
# builds here is not a mapped MoneyComposite attribute either. currency= is
# passed so that a missing currency is not an additional cause of failure.
even_against_odd = session.query(func.sum_money(
    case([
        (BankAccount.category == 'even', BankAccount.balance_even),
        (BankAccount.category == 'odd', BankAccount.balance_odd),
    ]),
    currency='USD')).scalar()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment