Last active
September 11, 2018 20:03
-
-
Save commandodev/5108455 to your computer and use it in GitHub Desktop.
REST traversal in Pyramid, with a small example of usage.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyramid.view import view_config | |
from sqlalchemy.ext.associationproxy import AssociationProxy | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import scoped_session, sessionmaker, object_mapper, ColumnProperty, SynonymProperty | |
# Application-wide scoped-session registry. Calling ``Session()`` returns the
# session for the current scope, and ``Session.query_property()`` builds the
# ``query`` attribute of the model base class. No engine is bound here; the
# binding has to be configured elsewhere.
Session = scoped_session(
    sessionmaker(),
)
class _PrettyPrintBase(object):
    """Base mixin for all of our declarative tables.

    Provides a ``query`` class attribute bound to ``Session`` and a readable
    ``str``/``repr`` built from the instance's primary-key values, e.g.
    ``<User: id=3>``.

    .. note:: Don't use this directly; it is a mixin to be passed as ``cls``
       to :func:`~sqlalchemy.ext.declarative.declarative_base`.
    """

    query = Session.query_property()

    def __str__(self):
        return self._pk

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    @property
    def _pk(self):
        """Comma-separated ``key=value`` pairs for the primary-key columns.

        A ``None`` value renders as ``NONE``. An object without a
        ``__table__`` renders as ``"No Table"``.
        """
        if not hasattr(self, '__table__'):
            return "No Table"
        mapper = object_mapper(self)
        pk_cols = mapper.primary_key
        pk_column_keys = [p.key for p in mapper.iterate_properties
                          if isinstance(p, ColumnProperty)
                          and p.columns[0] in pk_cols]
        pairs = ((key, getattr(self, key)) for key in pk_column_keys)
        # Compare against None explicitly. A legitimate key such as 0 or ""
        # is falsy and must still be shown as its real value.
        return ', '.join('%s=%s' % (key, 'NONE' if value is None else value)
                         for key, value in pairs)


# Declarative base for every model; each model inherits the mixin above.
Base = declarative_base(cls=_PrettyPrintBase)
class BaseTraverser(object):
    """Base class for Pyramid traversal resources.

    Subclasses implement :meth:`getitem` to look up a child resource.
    :meth:`__getitem__` calls it and stamps the child with the ``__name__``
    and ``__parent__`` attributes that Pyramid's location-aware traversal
    relies on.
    """

    base = None  # declarative base; read by DBTraverser.getitem

    def getitem(self, item):
        """Return the child resource for ``item``; raise KeyError if none.

        The default has no children. The missing key is carried in the
        exception for easier debugging.
        """
        raise KeyError(item)

    def __getitem__(self, item):
        resource = self.getitem(item)
        resource.__name__ = item
        resource.__parent__ = self
        return resource
class DBTraverser(BaseTraverser):
    """Traversal node whose children are the model classes of a declarative base.

    Subclasses must set ``base`` to the declarative base to expose; with the
    inherited ``base = None`` every lookup fails inside ``get_model``. Path
    segments are model class names.
    """

    def getitem(self, item):
        model = get_model(self.base, item)
        # The declarative class registry may also hold bookkeeping entries
        # that are not mapped classes (e.g. SQLAlchemy's module-registry
        # marker, or a marker for duplicate class names). Treat them as
        # missing rather than wrapping them as models.
        if not isinstance(model, type):
            raise KeyError(item)
        return ModelResource(item, model)
def _association_proxy_items(model_inst):
    """Yield ``(name, value)`` for every association proxy visible on ``model_inst``."""
    seen = set()
    # Walk the MRO so proxies declared on mixins or parent classes are found
    # too, not only those in the concrete class's own __dict__.
    for klass in model_inst.__class__.__mro__:
        for name, attr in vars(klass).items():
            # The most-derived definition of a name wins, as in normal
            # attribute lookup, so a subclass can shadow an inherited proxy.
            if name in seen:
                continue
            seen.add(name)
            if not isinstance(attr, AssociationProxy):
                continue
            try:
                value = getattr(model_inst, name)
            except AttributeError:
                continue
            try:
                # Detach collection proxies from the instance (presumably
                # yielding a plain list/set/dict).
                value = value.copy()
            except AttributeError:
                pass  # scalar proxy value: nothing to copy
            yield name, value


def model_to_dict(model_inst):
    """Generic function to convert a database model instance to a dict.

    This is to enable serialization to JSON at a later stage. The dict holds
    every mapped column attribute, plus the value of every association proxy
    defined on the model's class or its bases. A column wrapped by a synonym
    is reported under the synonym's key instead of its own.

    :param model_inst: an instance of a mapped (declarative) class.
    :returns: ``dict`` mapping attribute name to value.
    """
    mapper = object_mapper(model_inst)
    props = list(mapper.iterate_properties)
    synonyms = [p for p in props if isinstance(p, SynonymProperty)]
    # SynonymProperty.name is the attribute the synonym stands in for; that
    # underlying column is reported only through the synonym.
    synonymed_column_names = set(s.name for s in synonyms)
    keys = [p.key for p in props
            if isinstance(p, ColumnProperty)
            and p.key not in synonymed_column_names]
    keys.extend(p.key for p in synonyms)
    result = dict((key, getattr(model_inst, key)) for key in keys)
    # Proxy values are applied last, so they win over a same-named column.
    result.update(_association_proxy_items(model_inst))
    return result
def get_model(base, name):
    """Return the class registered under ``name`` on a declarative base.

    :param base: a :ref:`sqlalchemy:declarative_toplevel` base class, i.e. a
        class created by ``declarative_base`` (here with
        :class:`_PrettyPrintBase` as its ``cls``).
    :param name: the class name to look up.
    :raises KeyError: if nothing is registered under ``name``.
    """
    registry = base._decl_class_registry
    return registry[name]
class ModelResource(BaseTraverser):
    """Traversal node for one model class; its children are rows by primary key.

    Assumes the model has a single integer primary key, since path segments
    are converted with ``int()``.
    """

    def __init__(self, name, Model):
        self.name = name  # the registry name the model was looked up under
        self.ses = Session()
        self.Model = Model  # also inspected by the model_is view predicate

    @property
    def q(self):
        """A new query over ``Model`` in this resource's session."""
        return self.ses.query(self.Model)

    def getitem(self, primary_key):
        # A non-numeric segment and a missing row must both raise KeyError so
        # Pyramid's traverser stops here. Otherwise a ValueError escapes as a
        # server error, or None gets wrapped in an ItemResource.
        try:
            pk = int(primary_key)
        except ValueError:
            raise KeyError(primary_key)
        instance = self.q.get(pk)
        if instance is None:
            raise KeyError(primary_key)
        return ItemResource(self.name, instance)
class ItemResource(BaseTraverser):
    """Traversal node wrapping a single model instance.

    Its children are the instance's relationship attributes, each exposed as
    a :class:`RelationResource`.
    """

    def __init__(self, model_name, model_instance):
        self.name = model_name
        self.model = model_instance

    def getitem(self, relation_name):
        # Function-scope import: the module's import block does not pull in
        # RelationshipProperty.
        from sqlalchemy.orm import RelationshipProperty

        if self.model is None:
            raise KeyError("No relation %s" % relation_name)
        relations = dict(
            (prop.key, prop)
            for prop in object_mapper(self.model).iterate_properties
            if isinstance(prop, RelationshipProperty))
        prop = relations.get(relation_name)
        # Only mapped relationships are traversable. Any other attribute
        # (a column, a method, ``__class__``...) would be wrapped and then
        # fail when the list view iterates it.
        if prop is None:
            raise KeyError("No relation %s" % relation_name)
        children = getattr(self.model, relation_name)
        if not prop.uselist:
            # Scalar (e.g. many-to-one) relationship: present it as a list of
            # zero or one objects so RelationResource always holds a sequence.
            children = [] if children is None else [children]
        return RelationResource(children)
class RelationResource(object):
    """Leaf resource holding the objects reached through a relationship.

    It defines no ``__getitem__``, so Pyramid traversal ends at this node.
    """

    def __init__(self, child_list):
        # The related objects; the list_related view iterates over them.
        self.children = child_list
def model_is(model_class):
    """Build a Pyramid custom predicate matching one specific model class.

    The predicate is true only when the context has a ``Model`` attribute
    that is exactly ``model_class``. It is given a readable ``__name__``.
    """
    def _model_is(context, request):
        if not hasattr(context, "Model"):
            return False
        return context.Model is model_class

    _model_is.__name__ = "model_is_%s" % model_class.__name__
    return _model_is
@view_config(context=ModelResource, renderer="safe_json")
def list_model(context, request):
    """List every row of the traversed model as dicts.

    An optional ``q`` query-string parameter holds a JSON object of
    ``{attribute: value}`` equality filters that are passed to
    ``Query.filter_by``, e.g. ``?q={"name": "bob"}``. A malformed filter is
    answered with 400 Bad Request rather than an unhandled server error.
    """
    # json is not imported by the module's import block, so import it here
    # together with the other helpers this view needs.
    import json
    from pyramid.httpexceptions import HTTPBadRequest
    from sqlalchemy.exc import InvalidRequestError

    query = context.q
    raw_filters = request.GET.get("q")
    if raw_filters:
        try:
            filters = json.loads(raw_filters)
        except ValueError:
            raise HTTPBadRequest("q must be a JSON object")
        # filter_by(**filters) needs a mapping of attribute names.
        if not isinstance(filters, dict):
            raise HTTPBadRequest("q must be a JSON object")
        try:
            query = query.filter_by(**filters)
        except InvalidRequestError:
            # filter_by rejects names that are not mapped attributes.
            raise HTTPBadRequest("q names an unknown attribute")
    return [model_to_dict(row) for row in query.all()]
@view_config(context=ItemResource, renderer="safe_json")
def model_detail(context, request):
    """Render the single row wrapped by an ItemResource as a dict."""
    instance = context.model
    return model_to_dict(instance)
@view_config(context=RelationResource, renderer="safe_json")
def list_related(context, request):
    """Render every object reached through a relationship as a list of dicts."""
    return list(map(model_to_dict, context.children))
@view_config(context=DBTraverser, name="models", renderer="safe_json")
def models(context, request):
    """List the model class names registered on this traverser's base.

    Registry entries that are not classes (SQLAlchemy bookkeeping markers)
    are skipped. The result is a sorted list, so it is stable between
    requests and serializable by a JSON renderer. A ``dict.keys()`` view is
    not JSON-serializable on Python 3.
    """
    registry = context.base._decl_class_registry
    return sorted(name for name, value in registry.items()
                  if isinstance(value, type))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyramid.config import Configurator
from pyramid.view import view_config
from sqlalchemy import engine_from_config

from rest_traversal import Base, DBTraverser, BaseTraverser
from rest_traversal import ModelResource, model_is
class MyTraverser(DBTraverser):
    """DBTraverser exposing the models declared on this application's ``Base``."""

    base = Base
class Root(BaseTraverser):
    """Application root resource, returned by ``app_root_factory``.

    Top-level path segments are dispatched through ``ROUTES``.
    """

    __parent__ = None
    __name__ = None
    ROUTES = {
        # Must be a traverser with ``base`` set. A bare DBTraverser inherits
        # ``base = None``, so every lookup under /db would fail with
        # AttributeError instead of finding a model.
        "db": MyTraverser(),
    }

    def getitem(self, item):
        return self.ROUTES[item]
# A single Root instance, shared by every request.
root = Root()


def app_root_factory(request):
    """Pyramid root factory: hand every request the shared ``root``."""
    shared_root = root
    return shared_root
class MyModel(Base):
    """Example model mapped to ``a_table`` in schema ``a_schema``."""

    __tablename__ = "a_table"
    # Keyword-only form of __table_args__ (no positional Table arguments).
    __table_args__ = {'schema': 'a_schema', 'extend_existing': True}
    # columns for a_table go here
    # NOTE(review): declarative mapping needs a primary-key column, so this
    # placeholder presumably must be filled in before the module imports.
@view_config(context=ModelResource, request_method="POST", renderer="safe_json",
             custom_predicates=[model_is(MyModel)])
def list_model(context, request):
    """Handle POST requests to MyModel's collection resource.

    The ``model_is(MyModel)`` predicate restricts this view to a
    ModelResource wrapping MyModel; requests for other models do not match.
    """
    # special logic for MyModel
    return {}
def main(global_conf, **settings):
    """This function returns a WSGI application.

    :param global_conf: global configuration (presumably PasteDeploy's
        ``[DEFAULT]`` section). It is merged into ``settings`` and, as
        written, overrides same-named app settings.
    :param settings: the app's own settings, including the ``sqlalchemy.*``
        keys read by ``engine_from_config``.
    """
    settings.update(global_conf)
    engine = engine_from_config(settings)
    initialize_sql(engine, Base)
    config = Configurator(settings=settings, root_factory=app_root_factory)
    config.include("pyramid_jinja2")
    # NOTE(review): the views use renderer="safe_json", which nothing here
    # registers. Presumably an include or other module adds it; confirm, or
    # register it with config.add_renderer().
    # Redirect "/foo" to "/foo/" when only the slashed URL has a view. This
    # replaces the deprecated append_slash_notfound_view registration.
    config.add_notfound_view(append_slash=True)
    config.scan()
    return config.make_wsgi_app()
def initialize_sql(engine, base):
    """Bind ``base``'s metadata and the shared scoped session to ``engine``."""
    # Function-scope import keeps this module's import block unchanged.
    from rest_traversal import Session

    base.metadata.bind = engine
    # Bind the session explicitly so queries no longer depend on the
    # bound-metadata fallback, which SQLAlchemy 1.4 deprecates and 2.0
    # removes.
    Session.configure(bind=engine)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment