Created
February 12, 2010 17:47
-
-
Save showyou/302792 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding:utf-8 -*- | |
| import hashlib | |
| from sqlalchemy import create_engine, MetaData, Table, Column, types | |
| from sqlalchemy.orm import sessionmaker, scoped_session, mapper as sqla_mapper | |
| from sqlalchemy.orm.query import Query | |
| from sqlalchemy.sql import visitors | |
| from werkzeug.contrib.cache import MemcachedCache | |
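# TODO: write tests for the query cache. Add a model row and read it back:
# first check that the repeated read is served from the cache, then check
# that the cached entry is no longer used once the timeout has expired.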
class CacheingQuery(Query):
    """Query subclass that caches its result lists in memcached.

    Results are stored under a key derived from the compiled SQL and the
    query's bound-parameter values and expire after ``_timeout`` seconds.
    Objects coming out of the cache are merged back into the current session.
    """

    # Seconds a cached result list stays valid (passed as ``timeout`` to set()).
    _timeout = 60
    # memcached server addresses used when the shared client is first created.
    _cache_servers = ['127.0.0.1:11211']
    # Shared MemcachedCache client, created lazily by get_cache_engine().
    _cache_engine = None

    def use_cache(self):
        """Return whether a cached result may be served for this query.

        When this returns False, any existing entry is deleted and the result
        is re-read from the database (and then stored again).
        """
        return True

    # TODO: revisit how the cache client is created and shared.
    def get_cache_engine(self):
        """Return the shared MemcachedCache client, creating it on first use."""
        # Stored on CacheingQuery itself (not type(self)), so every instance
        # and subclass shares one client. Creation is not guarded by a lock.
        if CacheingQuery._cache_engine is None:
            CacheingQuery._cache_engine = MemcachedCache(
                CacheingQuery._cache_servers)
        return CacheingQuery._cache_engine

    @staticmethod
    def _to_bytes(value):
        """Return ``value`` as bytes, UTF-8 encoding text (Py2 and Py3)."""
        if isinstance(value, bytes):
            return value
        if not isinstance(value, type(u'')):
            value = u'%s' % (value,)
        return value.encode('utf-8')

    def _get_cache_key(self):
        """Return a memcached-safe key identifying this query and its params.

        The compiled SQL and every bound-parameter value are fed into one MD5
        digest. Hashing everything keeps the key at a fixed 32 hex characters
        (within memcached's key-length limit, with no whitespace or control
        characters). Length-prefixing each parameter keeps distinct parameter
        lists from producing the same key.
        """
        digest = hashlib.md5(self._to_bytes(str(self)))
        for arg in _params_from_query(self):
            data = self._to_bytes(arg)
            digest.update(('|%d:' % len(data)).encode('ascii'))
            digest.update(data)
        return digest.hexdigest()

    def __iter__(self):
        """Iterate the results, serving them from memcached when possible.

        On a cache miss (or when use_cache() is False) the query runs against
        the database and the full result list is cached for ``_timeout``
        seconds. Every object is merged into ``self.session`` with
        ``dont_load=True`` before being yielded.
        """
        cache_key = self._get_cache_key()
        cache_engine = self.get_cache_engine()
        if self.use_cache():
            ret = cache_engine.get(cache_key)
        else:
            # Drop any stale entry; a get() right after delete() would always
            # miss, so skip that round trip. The fresh result is stored below.
            cache_engine.delete(cache_key)
            ret = None
        if ret is None:
            ret = list(Query.__iter__(self))
            cache_engine.set(cache_key, ret, timeout=self._timeout)
        # NOTE(review): merge() presumably only accepts mapped instances, so
        # column/tuple queries (e.g. query(Model.id)) may fail here — confirm.
        return (self.session.merge(x, dont_load=True) for x in ret)
| def _params_from_query(query): | |
| v = [] | |
| def visit_bindparam(bind): | |
| value = query._params.get(bind.key, bind.value) | |
| if callable(value): | |
| value = value() | |
| k = bind.params() | |
| v.append("%s:%s" % (k, value)) | |
| for obj in query._from_obj: | |
| _criterion = obj.onclause | |
| if _criterion is not None: | |
| visitors.traverse(_criterion, {}, {'bindparam': visit_bindparam} ) | |
| if query._criterion is not None: | |
| visitors.traverse(query._criterion, {}, {'bindparam': visit_bindparam} ) | |
| return v | |
# Schema registry shared by the table definitions below and create_all().
metadata = MetaData()

# Table "test": integer primary key plus a unicode name of up to 32 chars.
testmodel = Table(
    "test",
    metadata,
    Column('id', types.Integer, primary_key=True),
    Column('name', types.Unicode(32)),
)
class TestModel(object):
    """Plain class mapped onto the ``test`` table by init() (classical mapping)."""
def init(url='sqlite:///:memory:', echo=True):
    """Set up the database and return a caching scoped session.

    Creates the engine, maps TestModel onto the ``test`` table, creates the
    tables and returns a scoped_session whose queries use CacheingQuery.

    :param url: database URL passed to create_engine(); defaults to an
        in-memory SQLite database (the previous hard-coded value).
    :param echo: passed to create_engine() to control SQL echoing; defaults
        to True (the previous hard-coded value).
    :returns: a scoped_session factory configured with ``autocommit=False``
        and ``query_cls=CacheingQuery``.

    NOTE(review): sqla_mapper() runs on every call, and mapping TestModel a
    second time presumably raises — call init() once per process.
    """
    engine = create_engine(url, echo=echo)
    session = scoped_session(sessionmaker(bind=engine,
                                          autocommit=False,
                                          query_cls=CacheingQuery))
    sqla_mapper(TestModel, testmodel)
    metadata.create_all(bind=engine)
    return session
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment