Created November 16, 2009 22:35
-
-
Save storborg/236385 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py | |
| index e126fe6..1862449 100644 | |
| --- a/lib/sqlalchemy/ext/associationproxy.py | |
| +++ b/lib/sqlalchemy/ext/associationproxy.py | |
| @@ -266,6 +266,33 @@ class AssociationProxy(object): | |
| 'no proxy_bulk_set supplied for custom ' | |
| 'collection_class implementation') | |
# NOTE: these definitions belong in the body of ``class AssociationProxy``;
# the class header lies outside this excerpt.
def _proxy_filter_operator(name, target_name=None):
    """Build a query operator that forwards to the proxied relation.

    The generated operator is called on the class-level proxy (e.g.
    ``Page.tags``).  It turns a plain value into a keyword criterion on
    ``value_attr`` and passes it to the ``target_name`` comparator method
    of the underlying relation attribute.  For example, given::

        tags = association_proxy('tag_objects', 'name')

    the filter operator::

        Page.tags.any('foo')

    becomes::

        Page.tag_objects.any(name='foo')

    :param name: public name given to the generated operator.
    :param target_name: comparator method invoked on the relation
        attribute; defaults to ``name``.
    :returns: a function suitable for use as a method of the proxy.
    """
    if target_name is None:
        target_name = name

    def _wrapped_operator(self, object):
        """Forward ``object`` as ``value_attr=object`` to the relation."""
        coll = getattr(self.owning_class, self.target_collection)
        return getattr(coll, target_name)(**{self.value_attr: object})

    # Give each operator its own name instead of a shared
    # ``_wrapped_operator`` so reprs and tracebacks are meaningful.
    _wrapped_operator.__name__ = name
    return _wrapped_operator

# Collection proxies, e.g. ``Page.tags.any('foo')``.
any = _proxy_filter_operator('any')
# Scalar proxies, e.g. ``Child.pops.has('p1')``.
has = _proxy_filter_operator('has')
# The relation comparator's contains() expects a mapped instance rather than
# keyword criteria, so "the collection holds a member whose value_attr equals
# object" is expressed through any().
contains = _proxy_filter_operator('contains', 'any')
# The relation comparator's __eq__/__ne__ take a single positional argument
# and fail on keyword criteria; compare the scalar target's value_attr
# through has() instead.
__eq__ = _proxy_filter_operator('__eq__', 'has')


def __ne__(self, object):
    """Return the SQL negation (``~``) of the ``__eq__`` criterion."""
    return ~self.__eq__(object)

# The factory is only needed while the class body executes; remove it so it
# is not left behind as a meaningless method on the proxy.
del _proxy_filter_operator
| + | |
| class _lazy_collection(object): | |
| def __init__(self, obj, target): | |
| self.ref = weakref.ref(obj) | |
| diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py | |
| index 4a57752..2bbcf29 100644 | |
| --- a/test/ext/test_associationproxy.py | |
| +++ b/test/ext/test_associationproxy.py | |
| @@ -46,6 +46,7 @@ class Parent(object): | |
| self.name = name | |
class Child(object):
    """Test child entity; ``pops`` proxies its parent's ``name``."""

    def __init__(self, name):
        self.name = name

    # Scalar association proxy over ``parent.name``; the filter tests use
    # it as ``Child.pops.has(...)``.
    pops = association_proxy('parent', 'name')
| @@ -946,4 +947,61 @@ class PickleKeyFunc(object): | |
| self.name = name | |
| def __call__(self, obj): | |
| - return getattr(obj, self.name) | |
| \ No newline at end of file | |
| + return getattr(obj, self.name) | |
| + | |
| + | |
class FilterProxyTest(TestBase):
    """Association-proxy filter operators match the direct relation queries.

    Fixture built in ``setup`` (parent name: child names)::

        p1: foo, bar, baz
        p2: baz
        p3: foo, quux
        p4: foo, baz
    """

    def setup(self):

        metadata = MetaData(testing.db)

        parents_table = Table('Parent', metadata,
                              Column('id', Integer, primary_key=True,
                                     test_needs_autoincrement=True),
                              Column('name', String(128)))
        children_table = Table('Children', metadata,
                               Column('id', Integer, primary_key=True,
                                      test_needs_autoincrement=True),
                               Column('parent_id', Integer,
                                      ForeignKey('Parent.id')),
                               Column('foo', String(128)),
                               Column('name', String(128)))

        mapper(Parent, parents_table,
               properties={'children': orm.relation(Child)})
        mapper(Child, children_table,
               properties={'parent': orm.relation(Parent)})
        metadata.create_all()

        self.metadata = metadata
        self.session = create_session()

        # ``Parent.kids`` is declared outside this excerpt; presumably it
        # proxies ``children.name`` and creates Child rows from plain names.
        p1 = Parent('p1')
        p1.kids.extend(['foo', 'bar', 'baz'])
        p2 = Parent('p2')
        p2.kids.extend(['baz'])
        p3 = Parent('p3')
        p3.kids.extend(['foo', 'quux'])
        p4 = Parent('p4')
        p4.kids.extend(['foo', 'baz'])

        self.session.add_all([p1, p2, p3, p4])
        self.session.flush()

    def teardown(self):
        # Release the session's connection before dropping tables; an open
        # connection can block DROP TABLE on some backends.
        self.session.close()
        self.metadata.drop_all()
        clear_mappers()

    def test_filter_any(self):
        q_proxy = self.session.query(Parent).\
            filter(Parent.kids.any('foo')).all()
        q_direct = self.session.query(Parent).\
            filter(Parent.children.any(name='foo')).all()
        assert set(q_proxy) == set(q_direct)
        # Pin the expected rows so the test cannot pass vacuously when both
        # queries return nothing.
        assert sorted(p.name for p in q_proxy) == ['p1', 'p3', 'p4']

    def test_filter_has(self):
        q_proxy = self.session.query(Child).\
            filter(Child.pops.has('p1')).all()
        q_direct = self.session.query(Child).\
            filter(Child.parent.has(name='p1')).all()
        assert set(q_proxy) == set(q_direct)
        # p1's children, per the fixture above.
        assert sorted(c.name for c in q_proxy) == ['bar', 'baz', 'foo']
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.