Created
September 18, 2012 04:55
-
-
Save tonyseek/3741335 to your computer and use it in GitHub Desktop.
A utility that gives an object a gateway for exchanging data with outside sources.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
#-*- coding:utf-8 -*- | |
import re | |
import copy | |
import logging | |
from db.sqlstore import store | |
class Gateway(object):
    """A gateway to fetch and store a subject object.

    Placed as a class attribute, a gateway acts as a descriptor: every
    attribute access returns a shallow copy of this prototype gateway whose
    ``subject`` is the accessing instance, or the owner class when the
    attribute is read from the class itself.
    """

    def __init__(self):
        #: The object this gateway is bound to; ``None`` on the prototype.
        self.subject = None

    def __get__(self, instance, instance_type):
        """Returns a copy of this gateway bound to *instance* or its class."""
        # Only an unbound prototype may be used as a descriptor.
        assert self.subject is None, \
            "The gateway has bound with %r" % (self.subject,)
        # Compare with None rather than relying on truthiness: an instance
        # that happens to be falsy (e.g. defines __len__ returning 0) must
        # still be bound as the instance, not as its class.
        if instance is not None:
            return self.make_from_prototype(instance)
        return self.make_from_prototype(instance_type)

    def make_from_prototype(self, subject):
        """Creates a bound gateway with current gateway as prototype.

        The prototype itself is left untouched; a shallow copy receives
        *subject* and is returned.
        """
        bound_gateway = copy.copy(self)
        bound_gateway.subject = subject
        return bound_gateway
class DatabaseGateway(Gateway):
    """A simple database gateway mapping rows of one table to subjects.

    Declared on a class with the column names, the table name and the
    primary-key column. Query rows are turned into subject objects by
    calling the subject's class with the row values, in ``columns`` order.
    """

    logger = logging.getLogger("corelib.gateway.DatabaseGateway")

    #: Accepted condition keys: a column name optionally followed by one of
    #: the comparison operators below (``=`` when omitted).
    re_where_expression = re.compile(r"^([a-zA-Z0-9_]+)(>=|<=|>|<|=|<>)?$")

    def __init__(self, columns, table, pk="id"):
        super(DatabaseGateway, self).__init__()
        self.table = table
        self.pk = pk
        self.columns = columns

    def get(self, pk):
        """Gets a subject object by primary key value.

        Returns ``None`` when no row matches.
        """
        sql = "select {0} from {1} where {2}=%s".format(self.build_columns(),
                                                        self.table, self.pk)
        query_result = self.execute_query(sql, pk)
        if query_result:
            return next(self.build_objects(query_result[0]))

    def get_all(self, *conditions, **options):
        """Gets a list of subject objects matching the given conditions.

        Each positional argument is a dict ``{column_expression: value}``.
        The items of one dict are joined with ``and``; separate dicts are
        joined with ``or``. A column expression is a column name optionally
        followed by ``>=``, ``<=``, ``>``, ``<``, ``=`` or ``<>`` (``=`` by
        default). An empty dict places no restriction on its group.

        Keyword options:

          * ``orderby``: a ``(column, "asc" | "desc")`` pair; the column must
            be one of ``self.columns``.
          * ``limit``: maximum number of rows; no limit when falsy.
          * ``offset``: rows to skip (default 0); only applied with ``limit``.

        Other keyword options are ignored.

        Example::

            >> class User(object):
                   db = DatabaseGateway(["id", "name", "age"], "users", "id")
                   ..
            >> user = User()
            >> user.name = "spam"
            >> user.age = 110
            >> user.db.save()        # inserts and assigns user.id
            >> user2 = User.db.get_all(orderby=("id", "desc"), limit=1)[0]
            >> user.name == user2.name
            True
            >> user.age == user2.age
            True

        More usage:

          * Gets by primary key: ``User.db.get(100001)``
            => "select ... from users where id=%s" (100001,)
          * Gets by column values:
            ``User.db.get_all({'name': "abc", 'age': 10},
                              {'name': "def", 'age': 20})``
            => "select ... from users where "
            => "(`name` = %s and `age` = %s) or (`name` = %s and `age` = %s)"
            => ("abc", 10, "def", 20)
            (the column order inside a group follows dict iteration order)
          * Gets the latest 10 users aged 20 or older:
            ``User.db.get_all({'age>=': 20}, orderby=("id", "desc"), limit=10)``
            => "select ... from users where (`age` >= %s) "
            => "order by `id` desc limit %s offset %s"
            => (20, 10, 0)
        """
        orderby = options.pop("orderby", None)
        limit, offset = options.pop("limit", None), options.pop("offset", 0)

        #: bound values, in placeholder order: conditions first, then limit
        values = []
        where_stmt = self._build_where_statement(conditions, values)
        orderby_stmt = self._build_orderby_statement(orderby)
        if limit:
            limit_stmt = "limit %s offset %s"
            values.append(limit)
            values.append(offset)
        else:
            limit_stmt = ""

        sql_suffix_stmt = " ".join([where_stmt, orderby_stmt, limit_stmt])
        sql = "select {0} from {1} {2}".format(self.build_columns(),
                                               self.table,
                                               sql_suffix_stmt)
        query_result = self.execute_query(sql, *values)
        return list(self.build_objects(*query_result))

    def _build_where_statement(self, conditions, values):
        """Builds the "where ..." clause and appends its bound values.

        Returns an empty string when there are no conditions.
        """
        or_groups = []
        for kv_pairs in conditions:
            and_terms = []
            for expression, value in kv_pairs.items():
                column, operator = self.build_where_expression(expression)
                and_terms.append("`{0}` {1} %s".format(column, operator))
                values.append(value)
            # An empty dict has no restriction; "1 = 1" keeps the SQL valid
            # instead of crashing on an empty group.
            or_groups.append(" and ".join(and_terms) or "1 = 1")
        if not or_groups:
            return ""
        return "where " + " or ".join("({0})".format(g) for g in or_groups)

    def _build_orderby_statement(self, orderby):
        """Builds the "order by ..." clause from a (column, order) pair."""
        if not orderby:
            return ""
        orderby_column, orderby_order = orderby
        assert orderby_column in self.columns, \
            "orderby column must be in %r" % list(self.columns)
        assert orderby_order.lower() in ("asc", "desc"), \
            "orderby option must be 'asc' or 'desc'"
        return "order by `{0}` {1}".format(orderby_column,
                                           orderby_order.lower())

    def save(self):
        """Inserts the bound subject, or updates it if its key exists.

        Issues ``INSERT ... ON DUPLICATE KEY UPDATE`` (MySQL syntax). Only
        columns the subject has as attributes are written. The primary-key
        column is written only when the subject already holds a truthy key.
        Otherwise the generated key is read from the connection afterwards
        and assigned to the subject.

        Returns the subject.
        """
        pk_value = getattr(self.subject, self.pk, None)
        columns = [column for column in self.columns
                   if (pk_value or column != self.pk)
                   and hasattr(self.subject, column)]
        values = [getattr(self.subject, column) for column in columns]

        insert_column_stmt = ", ".join("`{0}`".format(c) for c in columns)
        insert_values_stmt = ", ".join("%s" for _ in columns)
        update_column_stmt = ", ".join("`{0}` = %s".format(c) for c in columns)
        sql = "INSERT `{0}` ({1}) VALUES ({2}) ON DUPLICATE KEY UPDATE {3};"
        sql = sql.format(self.table, insert_column_stmt, insert_values_stmt,
                         update_column_stmt)

        # Every value is bound twice: once for VALUES, once for UPDATE.
        self.execute_query(sql, *(values * 2))

        if not pk_value:
            pk_value = store.get_cursor().connection.insert_id()
            setattr(self.subject, self.pk, pk_value)
        return self.subject

    def is_exists(self, pk):
        """Returns True when a row with primary key *pk* exists."""
        # "limit 1": only existence matters, so fetch at most one row.
        exists_sql = "select 1 from `{0}` where `{1}` = %s limit 1"
        exists_sql = exists_sql.format(self.table, self.pk)
        exists = self.execute_query(exists_sql, pk)
        return bool(exists)

    def build_columns(self):
        """Returns the configured columns as a backtick-quoted select list."""
        return ", ".join("`%s`" % c for c in self.columns)

    def build_objects(self, *query_results):
        """Builds subject objects from query result rows.

        Each row is passed positionally to the subject's class. When the
        gateway is bound to an instance (``user.db``), that instance's class
        is used, so instance-bound lookups such as ``user.db.get(1)`` work.
        """
        import inspect
        subject = self.subject
        factory = subject if inspect.isclass(subject) else subject.__class__
        for query_result in query_results:
            yield factory(*query_result)

    def build_where_expression(self, column_name_expression):
        """Splits a condition key into ``(column_name, operator)``.

        The operator defaults to ``=``. The regex admits only
        ``[a-zA-Z0-9_]`` column names, which keeps condition keys from
        injecting SQL.
        """
        matched = self.re_where_expression.match(column_name_expression)
        assert matched, "Invalid query condition expression."
        column_name, where_operator = matched.groups()
        return column_name, (where_operator or "=")

    @staticmethod
    def execute_query(sql_stmt, *args):
        """Executes the database querying."""
        # Lazy %-args: formatting only happens if DEBUG is enabled.
        DatabaseGateway.logger.debug("[QUERY]\n\t%s\n\t%r", sql_stmt, args)
        return store.execute(sql_stmt, args)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
#-*- coding:utf-8 -*- | |
import contextlib | |
from db.sqlstore import store | |
from gateway import Gateway, DatabaseGateway | |
class MockGateway(Gateway):
    """Mock Gateway carrying three arbitrary constructor arguments."""

    def __init__(self, arg1, arg2, arg3):
        super(MockGateway, self).__init__()
        self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3
        self.attr = "my-attribute"

    def get_args(self):
        """Renders the arguments as ``<arg1>-<arg2 as int>-<repr(arg3)>``."""
        fields = (self.arg1, self.arg2, self.arg3)
        return "%s-%d-%r" % fields
def test_binding_gateway():
    """Test for generating a bound gateway instance from prototype."""
    def assert_mock_gateway(mock_gateway):
        # These must be real assertions: bare comparisons are evaluated and
        # thrown away, so they never fail.
        assert mock_gateway.arg1 == "mock", mock_gateway.arg1
        assert mock_gateway.arg2 == 1024, mock_gateway.arg2
        assert mock_gateway.arg3 is True, mock_gateway.arg3
        assert mock_gateway.attr == "my-attribute", mock_gateway.attr
        assert mock_gateway.get_args() == "mock-1024-True", \
            mock_gateway.get_args()

    #: common assigning attributes
    mock_gateway = MockGateway("mock", 1024, True)
    assert_mock_gateway(mock_gateway)
    assert mock_gateway.subject is None

    #: test the bound gateway factory, check it is equal to prototype subject
    mock_subject = object()
    mock_bound_gateway = mock_gateway.make_from_prototype(mock_subject)
    assert_mock_gateway(mock_bound_gateway)
    assert mock_bound_gateway.subject is mock_subject
    # binding must not mutate the prototype
    assert mock_gateway.subject is None

    class MockSubject(object):
        gateway = mock_gateway

    #: test the bound gateway factory again but through the descriptor
    mock_subject = MockSubject()
    assert_mock_gateway(mock_subject.gateway)
    assert mock_subject.gateway.subject is mock_subject
    assert MockSubject.gateway.subject is MockSubject
class MockDetail(object):
    """Test subject persisted in `dbgw_mock_detail` through a gateway."""

    db = DatabaseGateway(["itemid", "title", "price"],
                         table="dbgw_mock_detail", pk="itemid")

    def __init__(self, itemid=None, title=None, price=None):
        # Parameter order matches the gateway's column list: the gateway
        # builds objects by passing row values positionally.
        self.itemid = itemid
        self.title = title
        self.price = price

    def __str__(self):
        # "%.2f" renders the price with two decimals ("%.2d" would
        # truncate it to an integer).
        return "%d-%s-%.2f" % (self.itemid, self.title, self.price)
@contextlib.contextmanager
def mockdb():
    """Yields the store with an empty `dbgw_mock_detail` table.

    Asserts that the store reports testing mode. The table is created if
    missing and emptied before the body runs. It is dropped afterwards,
    even if the body raises.
    """
    assert store.is_testing()
    cur = store.get_cursor()
    create_stmt = ("create table if not exists `dbgw_mock_detail` ("
                   "`itemid` int primary key AUTO_INCREMENT,"
                   "`title` varchar(100),"
                   "`price` decimal(64,10)"
                   ");")
    cur.execute(create_stmt)
    # Remove rows that may be left over from an earlier run.
    cur.execute("delete from `dbgw_mock_detail` where 1 = 1;")
    store.commit()
    try:
        yield store
    finally:
        cur.execute("drop table `dbgw_mock_detail`;")
        store.commit()
def test_get_by_pk():
    """get() returns the matching object, or None for an unknown key."""
    with mockdb() as store:
        # Seed one row so that primary key 1 can be found.
        store.execute("insert into `dbgw_mock_detail` "
                      "(`itemid`, `title`, `price`) "
                      "values (%s, %s, %s);", (1, "test1", 13.45))
        store.commit()
        found = MockDetail.db.get(1)

        # Guarantee primary key 0 is absent before looking it up.
        store.execute("delete from `dbgw_mock_detail` where `itemid` = 0;")
        store.commit()
        missing = MockDetail.db.get(0)

        assert found.itemid == 1, found.itemid
        assert found.title == "test1", found.title
        assert float(found.price) == 13.45, repr(found.price)
        assert missing is None
def test_get_by_condition():
    """get_all: keys of one dict are and-ed, separate dicts are or-ed."""
    fixtures = [(1, "test1", 45.8),
                (2, "test2", 12.7),
                (3, "test3", 100),
                (4, "test4", 2.43)]
    #: itemid -> (title, price as float) for every inserted row
    expected = dict((itemid, (title, float(price)))
                    for itemid, title, price in fixtures)

    def check_mock(mock):
        """Asserts that *mock* carries exactly the values inserted for it."""
        assert mock.itemid in expected, "Invalid calling"
        title, price = expected[mock.itemid]
        assert mock.title == title, mock.title
        assert float(mock.price) == price, mock.price

    def check_results(results):
        """Checks every object of every result list; True when all pass."""
        for mock in (m for result in results for m in result):
            check_mock(mock)
        return True

    with mockdb() as store:
        insert_sql = ("insert into `dbgw_mock_detail` (`itemid`, `title`, `price`) "
                      "values (%s, %s, %s);")
        for row in fixtures:
            store.execute(insert_sql, row)
        store.commit()

        query = MockDetail.db.get_all
        rv_none = [query({'price': 12.7, 'itemid': 1},
                         {'price>=': 100, 'itemid': 2}),
                   query({'price<': 2.43}),
                   query({'price>': 100}),
                   query({'price>': 45.8, 'price<': 100})]
        rv_has_1_2_3 = [query({'price>=': 12.7})]
        rv_has_1_3 = [query({'price>': 12.7}),
                      query({'price>=': 30}),
                      query({'price>': 30})]
        rv_has_1_2_4 = [query({'price<': 100}),
                        query({'price<=': 45.8})]

        assert all(len(r) == 0 for r in rv_none), rv_none
        for groups, ids in [(rv_has_1_2_3, set([1, 2, 3])),
                            (rv_has_1_3, set([1, 3])),
                            (rv_has_1_2_4, set([1, 2, 4]))]:
            assert check_results(groups)
            assert all(set(r.itemid for r in rv) == ids
                       for rv in groups), groups
def test_limit_and_orderby():
    """get_all honours orderby, limit and offset."""
    def expected_price(itemid):
        # Same expression used for the inserted value below.
        return itemid + itemid * 0.1

    def assert_mock(mock):
        # The column is decimal(64,10), so the stored value is rounded. An
        # exact float comparison against the Python expression is fragile.
        assert abs(float(mock.price) - expected_price(mock.itemid)) < 1e-6, \
            repr(mock.price)
        assert mock.title == "test%d" % mock.itemid

    with mockdb() as store:
        sql = ("insert into `dbgw_mock_detail` "
               "(`itemid`, `title`, `price`) "
               "values (%s, %s, %s);")
        for id_ in range(1, 101):
            store.execute(sql, (id_, "test%d" % id_, expected_price(id_)))
        store.commit()

        top_10 = MockDetail.db.get_all(orderby=("itemid", "desc"), limit=10)
        # rows 11..20 by ascending price (the second page of ten)
        next_10 = MockDetail.db.get_all(orderby=("price", "asc"),
                                        limit=10, offset=10)

        # Without length checks the loops below would pass on empty results.
        assert len(top_10) == 10, len(top_10)
        for num, mock in enumerate(top_10):
            assert mock.itemid == 100 - num, mock.itemid
            assert_mock(mock)

        assert len(next_10) == 10, len(next_10)
        for num, mock in enumerate(next_10):
            assert mock.itemid == 11 + num, mock.itemid
            assert_mock(mock)
def test_is_exists():
    """is_exists is true for an inserted key and false for the next one."""
    with mockdb() as store:
        store.execute("insert into `dbgw_mock_detail` (`title`, `price`) "
                      "values (%s, %s)", ("xxx", 12.35))
        new_id = store.get_cursor().connection.insert_id()
        store.commit()

        absent_id = new_id + 1
        assert MockDetail.db.is_exists(new_id), new_id
        assert not MockDetail.db.is_exists(absent_id), new_id
def test_save():
    """save() inserts, then updates in place, then inserts on a new key."""
    select_latest = ("select title, price from `dbgw_mock_detail` "
                     "order by `itemid` desc;")

    with mockdb() as store:
        # Start from an empty table.
        store.execute("delete from `dbgw_mock_detail` where 1=1;")

        # Insert: no primary key yet, so save() inserts and assigns one.
        mock = MockDetail(title="mock-title", price=998.98)
        mock.db.save()
        store.commit()
        rows = store.execute(select_latest)
        assert len(rows) == 1  # only one record in a cleaned table
        title, price = rows[0]
        assert title == "mock-title", title
        assert float(price) == 998.98, repr(price)
        assert mock.itemid is not None

        # Update: the key assigned by the insert is reused.
        first_id = mock.itemid
        mock.title = "updated-title"
        mock.price = 9998.98
        mock.db.save()
        store.commit()
        rows = store.execute(select_latest)
        assert len(rows) == 1, len(rows)  # an update, not another insert
        title, price = rows[0]
        assert title == "updated-title", title
        assert float(price) == 9998.98, repr(price)
        assert mock.itemid is not None
        assert mock.itemid == first_id, (mock.itemid, first_id)

        # Insert with an explicitly provided primary key value.
        mock.itemid = 10086
        mock.db.save()
        store.commit()
        rows = store.execute("select itemid, title, price "
                             "from `dbgw_mock_detail` "
                             "order by itemid desc limit 1")
        row_id, title, price = rows[0]
        assert row_id == 10086
        assert title == "updated-title", title
        assert float(price) == 9998.98, repr(price)

        total = store.execute("select count(*) from `dbgw_mock_detail`")[0][0]
        assert total == 2  # the old record has not been removed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment