-
-
Save a740122/ac6b1efa4e441108562de9d0cf6e8a1e to your computer and use it in GitHub Desktop.
Serialize a SQLAlchemy model to a dictionary (for JSON output) and update a model from a dictionary of attributes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import uuid
from datetime import datetime

import wtforms_json
from flask import current_app as app
from flask import request, json, jsonify, abort
# NOTE(review): the flask.ext namespace was removed in Flask 1.0; on modern
# Flask this import must be `from flask_sqlalchemy import SQLAlchemy`.
from flask.ext.sqlalchemy import SQLAlchemy
from sqlalchemy import not_
from sqlalchemy.dialects.postgresql import UUID
from wtforms import Form
from wtforms.fields import FormField, FieldList, StringField
from wtforms.validators import Length
# Flask-SQLAlchemy extension bound to `app`, which is flask.current_app.
# NOTE(review): because current_app is a context-bound proxy, this module
# presumably has to be imported inside an application context -- confirm.
db = SQLAlchemy(app)
# Enable WTForms-JSON support; the routes below rely on Form.from_json() and
# form.patch_data.
wtforms_json.init()
class Model(db.Model):
    """Base SQLAlchemy model with automatic serialization and deserialization
    of columns and nested relationships.

    Subclasses may define these optional class attributes:

    * ``default_fields``  -- names that :meth:`to_dict` always includes.
    * ``hidden_fields``   -- names that :meth:`to_dict` never emits and that
      :meth:`set_columns` never writes.
    * ``readonly_fields`` -- names that :meth:`set_columns` never writes (the
      constructor may still set them).

    ``id`` and the ``created``/``updated``/``modified`` timestamp columns
    (plus their ``*_at`` variants) are always read-only for
    :meth:`set_columns`; ``updated``/``modified`` (and ``*_at``) are stamped
    with the current UTC time on every :meth:`set_columns` call.

    Usage::

        >>> class User(Model):
        >>>     id = db.Column(db.Integer(), primary_key=True)
        >>>     email = db.Column(db.String(), index=True)
        >>>     name = db.Column(db.String())
        >>>     password = db.Column(db.String())
        >>>     posts = db.relationship('Post', backref='user', lazy='dynamic')
        >>>     ...
        >>>     default_fields = ['email', 'name']
        >>>     hidden_fields = ['password']
        >>>     readonly_fields = ['email', 'password']
        >>>
        >>> class Post(Model):
        >>>     id = db.Column(db.Integer(), primary_key=True)
        >>>     user_id = db.Column(db.String(), db.ForeignKey('user.id'), nullable=False)
        >>>     title = db.Column(db.String())
        >>>     ...
        >>>     default_fields = ['title']
        >>>     readonly_fields = ['user_id']
        >>>
        >>> model = User(email='john@localhost')
        >>> db.session.add(model)
        >>> db.session.commit()
        >>>
        >>> # update name and create a new post
        >>> validated_input = {'name': 'John', 'posts': [{'title':'My First Post'}]}
        >>> model.set_columns(**validated_input)
        >>> db.session.commit()
        >>>
        >>> print(model.to_dict(show=['password', 'posts']))
        >>> {u'email': u'john@localhost', u'posts': [{u'id': 1, u'title': u'My First Post'}], u'name': u'John', u'id': 1}
    """
    __abstract__ = True

    # Changes recorded by the most recent set_columns() call; read them via
    # the ``changes`` property. NOTE: until set_columns()/reset_changes()
    # assigns a per-instance dict, this class-level dict is shared by every
    # instance, so it must never be mutated in place.
    _changes = {}

    def __init__(self, **kwargs):
        # Construction may set any field, including read-only/hidden ones.
        kwargs['_force'] = True
        self._set_columns(**kwargs)

    def _readonly_fields(self):
        """Return a fresh list of the names set_columns() may not write."""
        # Copy instead of extending self.readonly_fields in place: that list
        # is a class attribute, and mutating it made it grow on every call.
        readonly = list(getattr(self, 'readonly_fields', []))
        readonly.extend(getattr(self, 'hidden_fields', []))
        readonly.extend([
            'id',
            'created',
            'updated',
            'modified',
            'created_at',
            'updated_at',
            'modified_at',
        ])
        return readonly

    def _set_columns(self, **kwargs):
        """Apply ``kwargs`` to this model's columns and relationships.

        Read-only/hidden names are skipped unless ``_force`` is truthy.
        Returns a dict describing what changed: ``{column: {'old', 'new'}}``
        for columns/plain relationships, a nested changes dict for scalar
        model relationships, and a list of per-row change dicts for
        collection relationships.
        """
        force = kwargs.get('_force')
        readonly = self._readonly_fields()

        def writable(key):
            return key in kwargs and (force or key not in readonly)

        changes = {}
        for key in self.__table__.columns.keys():
            if not writable(key):
                continue
            old, new = getattr(self, key), kwargs[key]
            if old != new:
                changes[key] = {'old': old, 'new': new}
                setattr(self, key, new)

        relationships = self.__mapper__.relationships
        for rel in relationships.keys():
            if not writable(rel):
                continue
            prop = relationships[rel]
            if prop.uselist:
                self._update_collection(rel, prop, kwargs[rel], changes)
                continue
            val = getattr(self, rel)
            if prop.query_class is not None:
                # Related model object: update it recursively in place.
                # NOTE: input for a currently-empty relationship is ignored.
                if val is not None:
                    col_changes = val.set_columns(**kwargs[rel])
                    if col_changes:
                        changes[rel] = col_changes
            elif val != kwargs[rel]:
                setattr(self, rel, kwargs[rel])
                changes[rel] = {'old': val, 'new': kwargs[rel]}
        return changes

    def _update_collection(self, rel, prop, items, changes):
        """Sync the one-to-many relationship ``rel`` with ``items`` (dicts).

        Items whose ``id`` already belongs to this collection are updated,
        all others are created, and rows missing from ``items`` are deleted.
        Per-row change dicts are appended to ``changes[rel]``.
        """
        # Assumes a lazy='dynamic' relationship: filter_by/append/filter are
        # query methods, not available on a plain list collection.
        query = getattr(self, rel)
        cls = prop.mapper.class_
        valid_ids = []
        for item in items:
            obj = query.filter_by(id=item['id']).first() if 'id' in item else None
            if obj is not None:
                col_changes = obj.set_columns(**item)
                if col_changes:
                    col_changes['id'] = str(item['id'])
                    changes.setdefault(rel, []).append(col_changes)
            else:
                obj = cls()
                col_changes = obj.set_columns(**item)
                query.append(obj)
                db.session.flush()  # assigns the new row's primary key
                if col_changes:
                    col_changes['id'] = str(obj.id)
                    changes.setdefault(rel, []).append(col_changes)
            valid_ids.append(obj.id)
        # Delete related rows that were not part of the submitted list.
        for stale in query.filter(not_(cls.id.in_(valid_ids))).all():
            changes.setdefault(rel, []).append({
                'id': str(stale.id),
                'deleted': True,
            })
            db.session.delete(stale)

    def set_columns(self, **kwargs):
        """Update this model from ``kwargs`` (see ``_set_columns``).

        Also stamps any ``modified``/``updated``/``modified_at``/``updated_at``
        column with the current UTC time (one timestamp for all of them).
        Returns the recorded changes, also available as ``self.changes``.
        """
        self._changes = self._set_columns(**kwargs)
        now = datetime.utcnow()
        for name in ('modified', 'updated', 'modified_at', 'updated_at'):
            if name in self.__table__.columns:
                setattr(self, name, now)
        return self._changes

    @property
    def changes(self):
        """Changes recorded by the most recent set_columns() call."""
        return self._changes

    def reset_changes(self):
        """Forget the changes recorded by set_columns()."""
        self._changes = {}

    def to_dict(self, show=None, hide=None, path=None, show_all=None):
        """Return a dictionary representation of this model.

        :param show: field paths to include besides ``id`` and
            ``default_fields``, e.g. ``['posts', 'posts.title']``. Paths are
            relative to this model or absolute (prefixed with the root table
            name). Matching is case-insensitive.
        :param hide: field paths to exclude (same format as ``show``).
        :param path: dotted path of this model from the root object; defaults
            to the lower-cased table name. Set internally when recursing.
        :param show_all: when truthy, include every non-hidden column,
            relationship and JSON-serializable public attribute.
        :returns: dict of attribute values; relationships become nested dicts
            (or lists of dicts). The caller's ``show``/``hide`` are not
            modified.

        NOTE(review): with ``show_all`` (or mutually-defaulted relationships)
        and a backref such as User.posts / Post.user, this appears to recurse
        without bound -- confirm before relying on it for such models.
        """
        # Work on copies so the caller's lists are never mutated.
        show = list(show or [])
        hide = list(hide or [])
        hidden = getattr(self, 'hidden_fields', [])
        default = getattr(self, 'default_fields', [])
        if not path:
            path = self.__tablename__.lower()
        # Absolute paths start with the root table name. Comparing against the
        # root (not the full nested path) keeps already-absolute entries from
        # being re-prefixed at every level of recursion.
        root = path.split('.', 1)[0]

        def prepend_path(item):
            item = item.lower()
            if not item or item.split('.', 1)[0] == root:
                return item
            if not item.startswith('.'):
                item = '.' + item
            return path + item

        show = [prepend_path(x) for x in show]
        hide = [prepend_path(x) for x in hide]

        columns = self.__table__.columns.keys()
        relationships = self.__mapper__.relationships
        rel_names = relationships.keys()
        ret_data = {}

        for key in columns:
            # Lower-cased to match the lower-cased show/hide entries.
            check = '%s.%s' % (path, key.lower())
            if check in hide or key in hidden:
                continue
            if show_all or key == 'id' or check in show or key in default:
                ret_data[key] = getattr(self, key)

        for key in rel_names:
            check = '%s.%s' % (path, key.lower())
            if check in hide or key in hidden:
                continue
            if not (show_all or check in show or key in default):
                continue
            hide.append(check)
            prop = relationships[key]
            value = getattr(self, key)
            if prop.uselist:
                ret_data[key] = [
                    item.to_dict(show=show, hide=hide, path=check,
                                 show_all=show_all)
                    for item in value
                ]
            elif prop.query_class is not None and value is not None:
                ret_data[key] = value.to_dict(show=show, hide=hide,
                                              path=check, show_all=show_all)
            else:
                # Covers an empty many-to-one relationship (None).
                ret_data[key] = value

        extra = set(dir(self)) - set(columns) - set(rel_names)
        for key in sorted(extra):
            if key.startswith('_'):
                continue
            check = '%s.%s' % (path, key.lower())
            if check in hide or key in hidden:
                continue
            if show_all or check in show or key in default:
                val = getattr(self, key)
                try:
                    # Round-trip keeps only JSON-compatible values.
                    ret_data[key] = json.loads(json.dumps(val))
                except (TypeError, ValueError, OverflowError):
                    # Not JSON-serializable (e.g. bound methods): skip it.
                    pass
        return ret_data
class User(Model):
    """A user who owns any number of posts."""

    # UUID primary key, generated in Python by uuid.uuid4 when not supplied.
    id = db.Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    first_name = db.Column(db.String(length=120))
    last_name = db.Column(db.String(length=120))
    # lazy='dynamic' makes ``user.posts`` a query object; Model's collection
    # update logic calls filter_by/append/filter on it.
    posts = db.relationship('Post', lazy='dynamic', backref='user')
class Post(Model):
    """A post written by a single user (see User.posts / Post.user)."""

    # UUID primary key, generated in Python by uuid.uuid4 when not supplied.
    id = db.Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    # Owning user; required.
    user_id = db.Column(
        UUID(as_uuid=True),
        db.ForeignKey('user.id'),
        nullable=False,
    )
    title = db.Column(db.String(length=200))
    text = db.Column(db.String())
class PostForm(Form):
    """Validates one entry of the ``posts`` list in a user update.

    NOTE(review): there is no ``id`` field, so submitted posts presumably
    never carry an id through ``patch_data``. Model._set_columns would then
    treat every post as new and delete the user's existing ones. Confirm
    this is intended.
    """

    # Mirrors the String(200) length of Post.title.
    title = StringField(validators=[Length(max=200)])
    text = StringField()
class UserForm(Form):
    """Validates the JSON body of ``PUT /users/<user_id>``."""

    # Lengths mirror the String(120) columns on User.
    first_name = StringField(validators=[Length(max=120)])
    last_name = StringField(validators=[Length(max=120)])
    # Each list element is validated by PostForm.
    posts = FieldList(FormField(PostForm))
def requested_columns(request):
    """Parse the comma-separated ``show`` query parameter of ``request``.

    Returns the requested field paths (e.g. ``['posts', 'posts.title']``),
    or an empty list when the parameter is missing or empty. Whitespace
    around entries is stripped and empty entries (``'a,,b'``, trailing
    commas) are dropped, since they can never match a field.
    """
    raw = request.args.get('show')
    if not raw:
        return []
    return [part.strip() for part in raw.split(',') if part.strip()]
@app.route('/users/<string:user_id>', methods=['GET'])
def read_user(user_id):
    """Return user ``user_id`` as JSON: ``{"data": {...}}``.

    Extra fields/relationships can be requested with ``?show=a,b.c``.
    Responds 404 when ``user_id`` is not a valid UUID or no such user exists.
    """
    # User.id is a UUID column; without this check a malformed id would be
    # sent to the database and fail there instead of producing a 404.
    try:
        uid = uuid.UUID(user_id)
    except ValueError:
        abort(404)
    user = User.query.filter_by(id=uid).first()
    if user is None:
        abort(404)
    show = requested_columns(request)
    return jsonify(data=user.to_dict(show=show))
@app.route('/users/<string:user_id>', methods=['PUT'])
def update_user(user_id):
    """Partially update user ``user_id`` from a JSON object body.

    Only keys present in the body (``form.patch_data``) are applied.
    Responds 404 for a malformed id or unknown user, and 400 for a non-object
    body (``{"error": ...}``) or failed validation (``{"errors": ...}``).
    Otherwise it commits and returns the updated user as ``{"data": {...}}``,
    with extra fields selectable via ``?show=``.
    """
    # User.id is a UUID column; without this check a malformed id would be
    # sent to the database and fail there instead of producing a 404.
    try:
        uid = uuid.UUID(user_id)
    except ValueError:
        abort(404)
    user = User.query.filter_by(id=uid).first()
    if user is None:
        abort(404)

    # force=True parses the body as JSON regardless of its Content-Type.
    input_data = request.get_json(force=True)
    if not isinstance(input_data, dict):
        return jsonify(error='Request data must be a JSON Object'), 400

    # Validate the JSON input using WTForms-JSON.
    form = UserForm.from_json(input_data)
    if not form.validate():
        return jsonify(errors=form.errors), 400

    user.set_columns(**form.patch_data)
    db.session.commit()

    show = requested_columns(request)
    return jsonify(data=user.to_dict(show=show))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.