Skip to content

Instantly share code, notes, and snippets.

@mcm
Created April 12, 2015 23:02
Show Gist options
  • Save mcm/8ca9f2cb2a7d86c1881f to your computer and use it in GitHub Desktop.
Save mcm/8ca9f2cb2a7d86c1881f to your computer and use it in GitHub Desktop.
import pymongo
from . import encoders
from . import hooks
class BaseResource:
    """Base class for a MongoDB-backed resource with per-endpoint hooks.

    A subclass enables an operation by defining a callable attribute named
    ``CreateEndpoint``, ``ReadEndpoint``, ``UpdateEndpoint`` or
    ``DeleteEndpoint``. Without it, the matching responder (``on_post``,
    ``on_get``, ``on_put``, ``on_delete``) raises ``NotImplementedError``.
    Only the Read flow (``on_get``) is implemented beyond that check.

    Hooks are callables registered per endpoint name ("Create", "Read",
    "Update", "Delete") and per hook point (e.g. "init", "pre_get_data").
    They are run in registration order.
    """

    def __init__(self):
        self.mongo_client = pymongo.MongoClient()
        # Per-instance hook registry: {endpoint: {hook: [callables]}}.
        self._hooks = {}

    @property
    def encoder(self):
        """Return a new ``encoders.ObjectIDEncoder`` used to encode responses."""
        return encoders.ObjectIDEncoder()

    def get_endpoint_fn(self, endpoint):
        """Return the bound responder method for *endpoint*.

        Raises:
            KeyError: if *endpoint* is not "Create", "Read", "Update" or "Delete".
        """
        return {
            "Create": self.on_post,
            "Read": self.on_get,
            "Update": self.on_put,
            "Delete": self.on_delete,
        }[endpoint]

    def _hook_registry(self):
        """Return this instance's hook registry, creating it on first use."""
        # Created lazily too, so instances whose __init__ chain did not reach
        # BaseResource.__init__ (or were built via __new__) still work.
        # Looked up in __dict__ so the registry is never shared via the class.
        return self.__dict__.setdefault("_hooks", {})

    def add_hook(self, endpoint, hook, target):
        """Register *target* to run at hook point *hook* of *endpoint*.

        Hooks registered for the same point run in registration order.

        Raises:
            KeyError: if *endpoint* is not a known endpoint name.
        """
        # Validates the endpoint name (KeyError for unknown names).
        self.get_endpoint_fn(endpoint)
        # Hooks are stored on the instance, not on the responder. Bound
        # methods reject attribute assignment, and attaching them to the
        # underlying function would share hooks across every instance.
        endpoint_hooks = self._hook_registry().setdefault(endpoint, {})
        endpoint_hooks.setdefault(hook, []).append(target)

    def run_hooks(self, endpoint, hook, req, resp, *args, **kwargs):
        """Run every hook registered for (*endpoint*, *hook*), in order.

        Each hook is called as ``hook_fn(req, resp, *args, **kwargs)``. It
        should return an ``(args, kwargs)`` pair, which becomes the extra
        arguments passed to the next hook. A hook returning ``None`` leaves
        them unchanged.

        Returns:
            The final ``(args, kwargs)`` pair. If no hooks are registered,
            this is the pair that was passed in.

        Raises:
            KeyError: if *endpoint* is not a known endpoint name.
        """
        self.get_endpoint_fn(endpoint)
        registered = self._hook_registry().get(endpoint, {}).get(hook, [])
        for hook_fn in registered:
            result = hook_fn(req, resp, *args, **kwargs)
            if result is not None:
                args, kwargs = result
        return (args, kwargs)

    def serialize_response(self, req):
        """Encode ``req.context["data"]`` and return the encoder's output.

        Each document's ``_id`` key is renamed to ``id`` in place. A
        ``pymongo.cursor.Cursor`` is consumed into a list of documents; any
        other value is treated as a single document. If ``req.context`` has
        a "pagination" entry, the encoded value is ``[pagination, data]``.
        ``req.context["data"]`` is overwritten with the value that gets
        encoded.

        NOTE(review): assumes *req* is a Falcon-style request whose
        ``context`` supports ``[]`` and ``in`` — confirm against callers.
        """
        def rename_id(doc):
            # None (no matching document) and documents without "_id"
            # (e.g. projected away) are passed through unchanged.
            if doc is not None and "_id" in doc:
                doc["id"] = doc.pop("_id")
            return doc

        data = req.context["data"]
        if isinstance(data, pymongo.cursor.Cursor):
            data = [rename_id(doc) for doc in data]
        else:
            data = rename_id(data)
        req.context["data"] = data

        # Pagination output
        if "pagination" in req.context:
            req.context["data"] = [req.context["pagination"], req.context["data"]]
        return self.encoder.encode(req.context["data"])

    def on_get(self, req, resp):
        """Handle GET through ``self.ReadEndpoint``, running "Read" hooks.

        Hook points run in this order: "init", "pre_get_data",
        "post_get_data", "pre_serialize_response",
        "post_serialize_response", "shutdown". Each hook receives
        ``(req, resp)``.

        Raises:
            NotImplementedError: if there is no callable ``ReadEndpoint``.
        """
        if not callable(getattr(self, "ReadEndpoint", None)):
            raise NotImplementedError
        # Run startup hooks
        self.run_hooks("Read", "init", req, resp)
        # Trigger the ReadEndpoint; presumably get_data() populates
        # req.context["data"], which serialize_response reads — confirm.
        self.run_hooks("Read", "pre_get_data", req, resp)
        self.ReadEndpoint(self.mongo_client).get_data(req)
        self.run_hooks("Read", "post_get_data", req, resp)
        # Set up the response
        self.run_hooks("Read", "pre_serialize_response", req, resp)
        resp.body = self.serialize_response(req)
        self.run_hooks("Read", "post_serialize_response", req, resp)
        # Cleanup
        self.run_hooks("Read", "shutdown", req, resp)

    def on_post(self, req, resp):
        """Handle POST; only checks that ``CreateEndpoint`` is defined so far.

        Raises:
            NotImplementedError: if there is no callable ``CreateEndpoint``.
        """
        if not callable(getattr(self, "CreateEndpoint", None)):
            raise NotImplementedError

    def on_put(self, req, resp):
        """Handle PUT; only checks that ``UpdateEndpoint`` is defined so far.

        Raises:
            NotImplementedError: if there is no callable ``UpdateEndpoint``.
        """
        if not callable(getattr(self, "UpdateEndpoint", None)):
            raise NotImplementedError

    def on_delete(self, req, resp):
        """Handle DELETE; only checks that ``DeleteEndpoint`` is defined so far.

        Raises:
            NotImplementedError: if there is no callable ``DeleteEndpoint``.
        """
        if not callable(getattr(self, "DeleteEndpoint", None)):
            raise NotImplementedError
class FilterMixin:
    """Mixin that registers ``hooks.filter_queryset`` on the Read endpoint.

    The class it is combined with must provide ``add_hook``. The hook is
    attached at the "pre_get_data" point after the rest of the MRO's
    ``__init__`` chain has run.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        register = self.add_hook
        # Attached before data retrieval; presumably this narrows the query
        # the Read endpoint runs — confirm against hooks.filter_queryset.
        register("Read", "pre_get_data", hooks.filter_queryset)
class PaginationMixin:
    """Mixin that registers ``hooks.paginate_queryset`` on the Read endpoint.

    The class it is combined with must provide ``add_hook``. The hook is
    attached at the "pre_get_data" point after the rest of the MRO's
    ``__init__`` chain has run.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        register = self.add_hook
        # Perform pagination on the queryset object before data is fetched.
        register("Read", "pre_get_data", hooks.paginate_queryset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment