Skip to content

Instantly share code, notes, and snippets.

@tonyseek
Created September 18, 2012 04:55
Show Gist options
  • Save tonyseek/3741335 to your computer and use it in GitHub Desktop.
Save tonyseek/3741335 to your computer and use it in GitHub Desktop.
A utility that gives an object a gateway for exchanging data with outside sources.
#!/usr/bin/env python
#-*- coding:utf-8 -*-
import re
import copy
import logging
from db.sqlstore import store
class Gateway(object):
    """A descriptor-based gateway used to fetch and store a subject object.

    A gateway is declared once as a class attribute and acts as a
    *prototype*.  Reading that attribute goes through :meth:`__get__`, which
    returns a shallow copy of the prototype whose ``subject`` is the instance
    the attribute was read from, or the owner class when it was read from
    the class itself.
    """

    def __init__(self):
        #: the object (or class) this gateway is bound to; ``None`` while
        #: the gateway is still an unbound prototype
        self.subject = None

    def __get__(self, instance, instance_type):
        """Return a copy of this gateway bound to the accessing object.

        :param instance: the instance the attribute was read from, or
                         ``None`` when it was read from the class.
        :param instance_type: the owner class.
        :returns: a new gateway whose ``subject`` is ``instance`` if it is
                  not ``None``, otherwise ``instance_type``.
        :raises AssertionError: if this gateway is already bound.
        """
        assert self.subject is None, \
            "The gateway has bound with %r" % (self.subject,)
        # Compare with ``None`` instead of relying on truthiness: an instance
        # that defines ``__len__``/``__bool__`` may be falsy, and it must
        # still be bound as an instance rather than as its class.
        if instance is not None:
            return self.make_from_prototype(instance)
        return self.make_from_prototype(instance_type)

    def make_from_prototype(self, subject):
        """Create a bound gateway, using this gateway as the prototype.

        :param subject: the object or class to bind the copy to.
        :returns: a shallow copy of this gateway with ``subject`` set; the
                  prototype itself is left untouched.
        """
        bound_gateway = copy.copy(self)
        bound_gateway.subject = subject
        return bound_gateway
class DatabaseGateway(Gateway):
    """A simple database gateway mapping one table onto a subject class.

    Rows are turned into objects by calling the bound ``subject`` with the
    selected column values as positional arguments (see
    :meth:`build_objects`), so the fetching methods are meant to be reached
    through the class (``User.db.get(1)``), and the class constructor must
    accept the values in the order given by ``columns``.  :meth:`save` reads
    attributes from the bound ``subject`` and is meant to be reached through
    an instance (``user.db.save()``).
    """

    logger = logging.getLogger("corelib.gateway.DatabaseGateway")

    #: a column name optionally followed by a comparison operator
    re_where_expression = re.compile(r"^([a-zA-Z0-9_]+)(>=|<=|>|<|=|<>)?$")

    def __init__(self, columns, table, pk="id"):
        """Describe the mapped table.

        :param columns: names of the columns to select and save.
        :param table: name of the table.
        :param pk: name of the primary key column.
        """
        super(DatabaseGateway, self).__init__()
        self.table = table
        self.pk = pk
        self.columns = columns

    def get(self, pk):
        """Get a subject object by primary key value.

        :param pk: the primary key value to look up.
        :returns: the object built from the first matching row, or ``None``
                  when the query returns nothing.
        """
        sql = "select {0} from {1} where {2}=%s".format(self.build_columns(),
                                                        self.table, self.pk)
        query_result = self.execute_query(sql, pk)
        if not query_result:
            return None
        return next(self.build_objects(query_result[0]))

    def get_all(self, *conditions, **options):
        """Get a list of subject objects.

        Each positional argument is a dict describing one group of
        conditions.  The entries of a dict are joined with ``and``; the
        groups are joined with ``or``.  A key is a column name, optionally
        followed by one of the operators ``>=``, ``<=``, ``>``, ``<``, ``=``
        or ``<>`` (the default is ``=``).  An empty dict is a group without
        any restriction, so passing one disables the ``where`` clause.

        Recognised keyword options:

        * ``orderby`` -- a ``(column, "asc" | "desc")`` pair; the column
          must be one of ``columns``.
        * ``limit`` -- maximum number of rows; a falsy value (``None``,
          ``0``) means no ``limit`` clause is emitted.
        * ``offset`` -- number of rows to skip (default ``0``); only used
          together with ``limit``.

        Example:

            >> class User(object):
                   db = DatabaseGateway(["id", "name", "age"], "users", "id")
                   def __init__(self, id=None, name=None, age=None):
                       ...
            >> user = User()
            >> user.name = "spam"
            >> user.age = 110
            >> user.db.save() is user
            True
            >> users = User.db.get_all(orderby=("id", "desc"),
                                       limit=1, offset=0)
            >> users[0].name == user.name
            True
            >> users[0].age == user.age
            True

        More usage:

        * Gets by primary key: `User.db.get(100001)`
              => "select ... where id=%s" (100001,)
        * Gets by columns values:
          `User.db.get_all({'name': "abc", 'age': 10},
                           {'name': "def", 'age': 20})`
              => "select ... from users where "
              => "(`name` = %s and `age` = %s) or "
              => "(`name` = %s and `age` = %s)"
              => ("abc", 10, "def", 20)
        * Gets the users who are at least 20 years old:
          `User.db.get_all({'age>=': 20})`
              => "select ... from users where (`age` >= %s)"
              => (20,)

        :returns: a list of objects, one per returned row.
        :raises AssertionError: if a condition key or the ``orderby`` option
                                is invalid.
        """
        orderby = options.pop("orderby", None)
        limit = options.pop("limit", None)
        offset = options.pop("offset", 0)

        where_stmt, values = self._build_where(conditions)
        orderby_stmt = self._build_orderby(orderby)

        #: limit and offset are bound as values, after the where values
        if limit:
            limit_stmt = "limit %s offset %s"
            values.append(limit)
            values.append(offset)
        else:
            limit_stmt = ""

        #: builds full sql statement
        sql_suffix_stmt = " ".join([where_stmt, orderby_stmt, limit_stmt])
        sql = "select {0} from {1} {2}".format(self.build_columns(),
                                               self.table,
                                               sql_suffix_stmt)
        query_result = self.execute_query(sql, *values)
        return list(self.build_objects(*query_result))

    def _build_where(self, conditions):
        """Build the ``where`` clause and its bound values.

        :param conditions: a sequence of condition dicts.
        :returns: a ``(statement, values)`` pair; the statement is an empty
                  string when there is nothing to restrict.
        """
        #: an empty dict is an "and" group without terms, which is always
        #: true; "or"-ing it with anything else matches every row
        if not all(conditions):
            return "", []
        values = []
        or_groups = []
        for condition in conditions:
            and_terms = []
            for expression, value in condition.items():
                column_and_operator = self.build_where_expression(expression)
                and_terms.append("`{0}` {1} %s".format(*column_and_operator))
                values.append(value)
            or_groups.append("({0})".format(" and ".join(and_terms)))
        where_stmt = " or ".join(or_groups)
        if where_stmt:
            where_stmt = "where {0}".format(where_stmt)
        return where_stmt, values

    def _build_orderby(self, orderby):
        """Build the ``order by`` clause from a ``(column, order)`` pair.

        The checks are raised explicitly instead of using ``assert`` so they
        still run under ``python -O``: the column name and direction are
        interpolated into the statement and must never reach it unchecked.
        ``AssertionError`` is kept as the exception type for compatibility.
        """
        if not orderby:
            return ""
        orderby_column, orderby_order = orderby
        if orderby_column not in self.columns:
            raise AssertionError(
                "orderby column must be in %r" % list(self.columns))
        orderby_order = orderby_order.lower()
        if orderby_order not in ("asc", "desc"):
            raise AssertionError("orderby option must be 'asc' or 'desc'")
        return "order by `{0}` {1}".format(orderby_column, orderby_order)

    def save(self):
        """Insert or update the row of the bound subject.

        Only the columns the subject actually has as attributes are written.
        The primary key column is left out while the subject's primary key
        value is falsy (``None``, ``0``); in that case the id reported by
        the connection after the statement is assigned back to the subject.

        :returns: the bound subject.
        """
        #: value of primary key field
        pk_value = getattr(self.subject, self.pk, None)
        #: names of the columns to write
        columns = [column for column in self.columns
                   if (pk_value or column != self.pk)
                   and hasattr(self.subject, column)]
        #: values for columns
        values = [getattr(self.subject, column) for column in columns]
        #: sql statement snips
        insert_column_stmt = ", ".join("`{0}`".format(c) for c in columns)
        insert_values_stmt = ", ".join("%s" for c in columns)
        update_column_stmt = ", ".join("`{0}` = %s".format(c)
                                       for c in columns)
        #: sql
        sql = "INSERT `{0}` ({1}) VALUES ({2}) ON DUPLICATE KEY UPDATE {3};"
        sql = sql.format(self.table, insert_column_stmt, insert_values_stmt,
                         update_column_stmt)
        #: the values are bound twice: once for the insert part and once
        #: for the update part
        values *= 2
        self.execute_query(sql, *values)
        #: assigns primary key
        if not pk_value:
            pk_value = store.get_cursor().connection.insert_id()
            setattr(self.subject, self.pk, pk_value)
        return self.subject

    def is_exists(self, pk):
        """Return ``True`` if a row with the given primary key is found."""
        exists_sql = "select 1 from `{0}` where `{1}` = %s"
        exists_sql = exists_sql.format(self.table, self.pk)
        exists = self.execute_query(exists_sql, pk)
        return bool(exists)

    def build_columns(self):
        """Return the backtick-quoted, comma separated column list."""
        return ", ".join("`%s`" % c for c in self.columns)

    def build_objects(self, *query_results):
        """Lazily build one subject object per given row.

        Each row is unpacked into positional arguments of a call to the
        bound ``subject``.
        """
        for query_result in query_results:
            yield self.subject(*query_result)

    def build_where_expression(self, column_name_expression):
        """Split a condition key into its column name and operator.

        :param column_name_expression: e.g. ``"age"`` or ``"age>="``.
        :returns: a ``(column_name, operator)`` pair; the operator defaults
                  to ``"="``.
        :raises AssertionError: if the key is not a plain column name with
                                an optional supported operator.
        """
        matched = self.re_where_expression.match(column_name_expression)
        #: raised explicitly so the check is not removed by ``python -O``
        if not matched:
            raise AssertionError("Invalid query condition expression.")
        column_name, where_operator = matched.groups()
        return column_name, (where_operator or "=")

    @staticmethod
    def execute_query(sql_stmt, *args):
        """Log the statement and run it through the shared ``store``.

        :param sql_stmt: the statement with ``%s`` placeholders.
        :param args: the values to bind, passed to the store as one tuple.
        :returns: whatever ``store.execute`` returns.
        """
        #: logger arguments defer formatting until the record is emitted
        DatabaseGateway.logger.debug("[QUERY]\n\t%s\n\t%r", sql_stmt, args)
        return store.execute(sql_stmt, args)
#!/usr/bin/env python
#-*- coding:utf-8 -*-
import contextlib
from db.sqlstore import store
from gateway import Gateway, DatabaseGateway
class MockGateway(Gateway):
    """A gateway carrying a few plain attributes, used to test binding."""

    def __init__(self, arg1, arg2, arg3):
        super(MockGateway, self).__init__()
        self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3
        self.attr = "my-attribute"

    def get_args(self):
        """Return the three constructor arguments rendered as one string."""
        parts = (self.arg1, self.arg2, self.arg3)
        return "%s-%d-%r" % parts
def test_binding_gateway():
    """Test for generating a bound gateway instance from prototype."""

    def assert_mock_gateway(mock_gateway):
        # Each comparison needs its own ``assert``; a bare comparison is
        # evaluated and thrown away, so it would verify nothing.
        assert mock_gateway.arg1 == "mock", mock_gateway.arg1
        assert mock_gateway.arg2 == 1024, mock_gateway.arg2
        assert mock_gateway.arg3 is True, mock_gateway.arg3
        assert mock_gateway.attr == "my-attribute", mock_gateway.attr
        assert mock_gateway.get_args() == "mock-1024-True", \
            mock_gateway.get_args()

    #: common assigning attributes
    mock_gateway = MockGateway("mock", 1024, True)
    assert_mock_gateway(mock_gateway)

    #: test the bound gateway factory, check the copy keeps the prototype's
    #: attributes and is bound to the given subject
    mock_subject = object()
    mock_bound_gateway = mock_gateway.make_from_prototype(mock_subject)
    assert_mock_gateway(mock_bound_gateway)
    assert mock_bound_gateway.subject is mock_subject

    class MockSubject(object):
        gateway = mock_gateway

    #: test the bound gateway factory again, but through the descriptor
    mock_subject = MockSubject()
    assert_mock_gateway(mock_subject.gateway)
    assert mock_subject.gateway.subject is mock_subject
    assert MockSubject.gateway.subject is MockSubject
class MockDetail(object):
    """Subject class for the database gateway tests.

    The constructor takes the column values in the same order as the
    gateway's column list, because the gateway builds objects by calling
    the class with a row's values as positional arguments.
    """

    db = DatabaseGateway(["itemid", "title", "price"],
                         table="dbgw_mock_detail", pk="itemid")

    def __init__(self, itemid=None, title=None, price=None):
        self.itemid = itemid
        self.title = title
        self.price = price

    def __str__(self):
        # ``%.2f`` renders the price with two decimals; the previous
        # ``%.2d`` is an integer conversion and dropped the fraction.
        return "%d-%s-%.2f" % (self.itemid, self.title, self.price)
@contextlib.contextmanager
def mockdb():
    """Provide the shared ``store`` with an empty ``dbgw_mock_detail`` table.

    The table is created if needed and emptied before the body runs, and is
    dropped afterwards even when the body raises.

    :raises AssertionError: if ``store.is_testing()`` is false.
    """
    # Raised explicitly instead of using ``assert``: this guard protects
    # against deleting from and dropping a table outside a testing store,
    # so it must not disappear under ``python -O``.
    if not store.is_testing():
        raise AssertionError("mockdb() requires a store in testing mode")
    cursor = store.get_cursor()
    #: create test table
    cursor.execute("create table if not exists `dbgw_mock_detail` ("
                   "`itemid` int primary key AUTO_INCREMENT,"
                   "`title` varchar(100),"
                   "`price` decimal(64,10)"
                   ");")
    #: clean test table
    cursor.execute("delete from `dbgw_mock_detail` where 1 = 1;")
    store.commit()
    try:
        yield store
    finally:
        #: drop test table after test finished
        cursor.execute("drop table `dbgw_mock_detail`;")
        store.commit()
def test_get_by_pk():
    """``get`` returns the matching object, or ``None`` for a missing key."""
    insert_sql = ("insert into `dbgw_mock_detail` "
                  "(`itemid`, `title`, `price`) "
                  "values (%s, %s, %s);")
    delete_sql = "delete from `dbgw_mock_detail` where `itemid` = 0;"

    with mockdb() as db:
        #: a key that exists
        db.execute(insert_sql, (1, "test1", 13.45))
        db.commit()
        found = MockDetail.db.get(1)

        #: a key that is guaranteed not to exist
        db.execute(delete_sql)
        db.commit()
        missing = MockDetail.db.get(0)

    assert found.itemid == 1, found.itemid
    assert found.title == "test1", found.title
    assert float(found.price) == 13.45, repr(found.price)
    assert missing is None
def test_get_by_condition():
    """``get_all`` honours equality and comparison conditions."""
    insert_sql = ("insert into `dbgw_mock_detail` "
                  "(`itemid`, `title`, `price`) "
                  "values (%s, %s, %s);")
    #: itemid -> (title, price) that every loaded object must carry
    expected_rows = {
        1: ("test1", 45.8),
        2: ("test2", 12.7),
        3: ("test3", 100.0),
        4: ("test4", 2.43),
    }

    def find(*conditions):
        return MockDetail.db.get_all(*conditions)

    with mockdb() as db:
        db.execute(insert_sql, (1, "test1", 45.8))
        db.execute(insert_sql, (2, "test2", 12.7))
        db.execute(insert_sql, (3, "test3", 100))
        db.execute(insert_sql, (4, "test4", 2.43))
        db.commit()

        rv_none = [
            find({'price': 12.7, 'itemid': 1},
                 {'price>=': 100, 'itemid': 2}),
            find({'price<': 2.43}),
            find({'price>': 100}),
            find({'price>': 45.8, 'price<': 100}),
        ]
        rv_has_1_2_3 = [find({'price>=': 12.7})]
        rv_has_1_3 = [
            find({'price>': 12.7}),
            find({'price>=': 30}),
            find({'price>': 30}),
        ]
        rv_has_1_2_4 = [
            find({'price<': 100}),
            find({'price<=': 45.8}),
        ]

    def check_object(mock):
        """Compare one loaded object with the row inserted for its id."""
        assert mock.itemid in expected_rows, "Invalid calling"
        title, price = expected_rows[mock.itemid]
        assert mock.title == title, mock.title
        assert float(mock.price) == price, mock.price

    def objects_are_correct(results):
        """Check every object of every result list; True when all pass."""
        for result in results:
            for mock in result:
                check_object(mock)
        return True

    assert all(len(r) == 0 for r in rv_none), rv_none

    expectations = (
        (rv_has_1_2_3, [1, 2, 3]),
        (rv_has_1_3, [1, 3]),
        (rv_has_1_2_4, [1, 2, 4]),
    )
    for results, itemids in expectations:
        assert objects_are_correct(results)
        assert all(set(r.itemid for r in rv) == set(itemids)
                   for rv in results), results
def test_limit_and_orderby():
    """``get_all`` applies ``orderby``, ``limit`` and ``offset``."""
    #: hoisted out of the loop, the statement never changes
    sql = ("insert into `dbgw_mock_detail` "
           "(`itemid`, `title`, `price`) "
           "values (%s, %s, %s);")
    with mockdb() as store:
        for i in range(100):
            id_ = i + 1
            store.execute(sql, (id_, "test%d" % id_, id_ + id_ * 0.1))
        store.commit()
        top_10 = MockDetail.db.get_all(orderby=("itemid", "desc"), limit=10)
        top_20 = MockDetail.db.get_all(orderby=("price", "asc"),
                                       limit=10, offset=10)

    # Without these checks the loops below would pass vacuously when the
    # queries return nothing (or fewer rows than requested).
    assert len(top_10) == 10, len(top_10)
    assert len(top_20) == 10, len(top_20)

    #: highest ids first: 100, 99, ..., 91
    for num, mock in enumerate(top_10):
        assert mock.itemid == 100 - num, mock.itemid
        assert float(mock.price) == mock.itemid + mock.itemid * 0.1,\
            repr(mock.price)
        assert mock.title == "test%d" % mock.itemid

    #: the price grows with the id, so skipping the ten cheapest rows
    #: leaves ids 11, 12, ..., 20
    for num, mock in enumerate(top_20):
        assert mock.itemid == 11 + num, mock.itemid
        assert float(mock.price) == mock.itemid + mock.itemid * 0.1,\
            repr(mock.price)
        assert mock.title == "test%d" % mock.itemid
def test_is_exists():
    """``is_exists`` is true for a stored key and false for the next one."""
    insert_sql = ("insert into `dbgw_mock_detail` (`title`, `price`) "
                  "values (%s, %s)")
    with mockdb() as db:
        db.execute(insert_sql, ("xxx", 12.35))
        #: id generated for the row that was just inserted
        new_id = db.get_cursor().connection.insert_id()
        db.commit()

        gateway_found_it = MockDetail.db.is_exists(new_id)
        assert gateway_found_it, new_id
        gateway_found_next = MockDetail.db.is_exists(new_id + 1)
        assert not gateway_found_next, new_id
def test_save():
    """``save`` inserts new subjects, updates stored ones, keeps the key."""
    select_title_price = ("select title, price from `dbgw_mock_detail` "
                          "order by `itemid` desc;")
    select_latest = ("select itemid, title, price "
                     "from `dbgw_mock_detail` "
                     "order by itemid desc limit 1")
    count_rows = "select count(*) from `dbgw_mock_detail`"

    with mockdb() as db:
        #: clean the database
        db.execute("delete from `dbgw_mock_detail` where 1=1;")
        detail = MockDetail(title="mock-title", price=998.98)

        #: insert -- the subject has no primary key value yet
        detail.db.save()
        db.commit()
        rows = db.execute(select_title_price)
        assert len(rows) == 1  # only one record in a cleaned table
        stored_title, stored_price = rows[0]
        assert stored_title == "mock-title", stored_title
        assert float(stored_price) == 998.98, repr(stored_price)
        assert detail.itemid is not None

        #: update -- same primary key, changed values
        id_after_insert = detail.itemid  # assigned by the insert above
        detail.title = "updated-title"
        detail.price = 9998.98
        detail.db.save()
        db.commit()
        rows = db.execute(select_title_price)
        assert len(rows) == 1, len(rows)  # an update must not add a row
        stored_title, stored_price = rows[0]
        assert stored_title == "updated-title", stored_title
        assert float(stored_price) == 9998.98, repr(stored_price)
        #: should not change itemid
        assert detail.itemid is not None
        assert detail.itemid == id_after_insert, (detail.itemid,
                                                  id_after_insert)

        #: insert with a primary key value provided by the caller
        detail.itemid = 10086
        detail.db.save()
        db.commit()
        rows = db.execute(select_latest)
        stored_id, stored_title, stored_price = rows[0]
        assert stored_id == 10086
        assert stored_title == "updated-title", stored_title
        assert float(stored_price) == 9998.98, repr(stored_price)
        rows = db.execute(count_rows)
        record_count = rows[0][0]
        assert record_count == 2  # old record has not been removed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment