from pyramid.decorator import reify

from pyramid_traversalwrapper import LocationProxy

class traversable_attrs(object):
    """ Class decorator that installs a traversal wrapper on a class.

    Each keyword argument maps a traversal segment name to a factory
    (usually a class accepting ``parent`` and ``name`` keywords).
    Decorating a class sets its ``__wrapper__`` attribute to a
    ``LocationProxy`` subclass. That subclass's ``__getitem__`` looks the
    segment up in the keyword mapping and returns
    ``factory(parent=<the wrapper instance>, name=<segment>)``.

    For example::

        @traversable_attrs(users=UsersContainer)
        class Root(...):
            ...
    """

    def __init__(self, **kwargs):
        # Segment name -> factory, exposed on the generated wrapper class.
        self.iterable_attrs = kwargs

    def __call__(self, cls):
        segment_factories = self.iterable_attrs

        class AttrIterableWrapper(LocationProxy):
            iterable_attrs = segment_factories

            def __getitem__(self, segment):
                factory = self.iterable_attrs[segment]
                return factory(parent=self, name=segment)

        cls.__wrapper__ = AttrIterableWrapper
        return cls

class TraversalMixin(object):
    """ Simple mixin for traversal providing a common constructor
    and a way to get the request object.

    Instances carry the ``__parent__`` / ``__name__`` pair used by
    location-aware traversal. If no request was passed in, ``request``
    is resolved lazily from the parent.
    """
    __parent__ = None
    __name__ = None

    def __init__(self, request=None, parent=None, name=None):
        """
        :param request: the request for this resource. When ``None``, the
            ``request`` attribute is looked up from ``parent`` on first
            access instead.
        :param parent: the containing resource, stored as ``__parent__``.
        :param name: this resource's name within ``parent``, stored as
            ``__name__``.
        """
        self.__parent__ = parent
        self.__name__ = name

        if request is not None:
            # ``reify`` is a non-data descriptor, so this instance attribute
            # takes precedence over the lazy lookup below.
            self.request = request

    @reify
    def request(self):
        """ The request, taken from ``__parent__`` when not set explicitly.

        Returns ``None`` when there is no parent. The result is cached on
        the instance after the first access (``reify``).
        """
        # Compare against None instead of relying on truthiness. A parent
        # that defines ``__len__`` (e.g. an empty container) would otherwise
        # be treated as "no parent" and the request silently dropped.
        if self.__parent__ is not None:
            return self.__parent__.request
        return None

class ModelContainer(TraversalMixin):
    """ Base class for a container of SQLAlchemy model objects.

    Use the class variable __model__ to specify your SQLA model.
    By default, items are looked up by primary key. If the model does not
    have a distinct primary key, or you wish to use a different column for
    loading/naming objects, set __lookup__ to the name of that column
    attribute on the model.

    For example:
        class UsersContainer(ModelContainer):
            __model__ = User
            __lookup__ = 'login'

    The database session is assumed to be accessible through
    ``request.db``. If this is not the case, you should override the
    db property with your own way of accessing a session.
    """
    # Factory called as ``__wrapper__(obj, container, name)`` to give
    # looked-up objects a traversal location.
    __wrapper__ = LocationProxy
    # Name of the model attribute used for lookups; None means primary key.
    __lookup__ = None

    @reify
    def db(self):
        """ The database session, read from ``request.db`` and cached on
        the instance after the first access.
        """
        return self.request.db

    def add(self, **kwargs):
        """ Construct a ``__model__`` instance from ``kwargs``, add it to
        the session, and return it.

        Only ``db.add`` is called here; no flush or commit is issued.
        """
        m = self.__model__(**kwargs)
        self.db.add(m)
        return m

    def __getitem__(self, k):
        """ Load the object identified by ``k`` and wrap it for traversal.

        With ``__lookup__`` unset, ``k`` is passed to ``Query.get`` (a
        primary-key lookup). Otherwise the first row whose ``__lookup__``
        column equals ``k`` is used.

        :raises KeyError: if no matching object exists.
        """
        if self.__lookup__ is None:
            obj = self.db.query(self.__model__).get(k)
        else:
            attr = getattr(self.__model__, self.__lookup__)
            # NOTE(review): without an ORDER BY, .first() is arbitrary if the
            # lookup column is not unique -- confirm the column is unique.
            obj = self.db.query(self.__model__).filter(attr == k).first()
        if obj is None:
            raise KeyError(k)
        return self.__wrapper__(obj, self, k)

    def __iter__(self):
        """ Iterate over every ``__model__`` row.

        Unlike ``__getitem__``, the yielded objects are NOT wrapped with
        ``__wrapper__``, so they carry no traversal location.
        """
        return iter(self.db.query(self.__model__))