|
import six |
|
from peewee import Model, FieldAccessor |
|
from playhouse.postgres_ext import BinaryJSONField |
|
|
|
|
|
class GForeignKeyAccessor(FieldAccessor):
    """Instance descriptor used by ``GForeignKeyField``.

    The raw value kept in ``instance.__data__`` is a reference dict such as
    ``{'model': 'Compound', 'pk': 7}``:

    * ``'id'`` is accepted as a fallback key for ``'pk'``.
    * An optional ``'rel_field'`` key names the referenced field.

    Reading the attribute resolves the reference to a model object and caches
    it in ``instance.__rel__``.
    """

    def get_rel_instance(self, instance):
        """Return the object referenced by this field on ``instance``.

        The resolved object is cached in ``instance.__rel__``, so later reads
        do not query again.

        Returns:
            The related model object, or ``None`` when no reference is stored
            and the field is nullable.

        Raises:
            peewee.DoesNotExist: no reference is stored and the field is not
                nullable.
            KeyError: the stored ``'model'`` name is not in ``allowed_types``.
            Any exception propagated from ``rel_model.get`` (e.g. when no row
            matches the stored key).
        """
        # An object that was already resolved (or assigned directly) wins.
        if self.name in instance.__rel__:
            return instance.__rel__[self.name]

        value = instance.__data__.get(self.name)
        if value is None:
            if not self.field.null:
                # Without a stored reference the related model is unknown, so
                # peewee's generic DoesNotExist is raised rather than a
                # model-specific one.
                from peewee import DoesNotExist
                raise DoesNotExist
            return None

        model_name = value['model']
        rel_model = self.field.allowed_types[model_name]
        rel_field_name = (value.get('rel_field')
                          or self.field.get_ref_field(model_name).name)
        rel_field = getattr(rel_model, rel_field_name)

        obj = rel_model.get(rel_field == value.get('pk', value.get('id')))
        instance.__rel__[self.name] = obj
        return obj

    def __get__(self, instance, instance_type=None):
        """Resolve the related object on instances; return the field on the class."""
        if instance is not None:
            return self.get_rel_instance(instance)
        return self.field

    def __set__(self, instance, obj):
        """Store a reference on ``instance`` and mark the field dirty.

        ``obj`` may be:

        * an instance of one of ``allowed_types`` (e.g.
          ``fm.subject = compound``): it is stored as a reference dict and
          cached as the related object;
        * a ``'Model/pk'`` string (e.g. ``fm.subject = 'Compound/7'``): it is
          converted to a reference dict;
        * anything else (e.g. the JSON value supplied while loading from the
          db, or ``None``): it is stored as-is.

        Any cached related object is dropped when the stored reference changes.
        """
        if isinstance(obj, tuple(self.field.allowed_types.values())):
            instance.__data__[self.name] = self.field.db_value(obj).adapted
            instance.__rel__[self.name] = obj
        else:
            if isinstance(obj, six.string_types):
                data = self.field.db_value(obj).adapted
            else:
                # While loading from db; reference:
                # http://initd.org/psycopg/docs/extras.html#json-adaptation
                data = obj
            prev_value = instance.__data__.get(self.name)
            instance.__data__[self.name] = data
            if data != prev_value:
                # The stored reference changed, so a cached object is stale.
                instance.__rel__.pop(self.name, None)

        instance._dirty.add(self.name)
|
|
|
|
|
class GForeignKeyField(BinaryJSONField):
    """Generic foreign key stored as a JSONB reference.

    The column holds a dict ``{'model': <model class name>, 'pk': <key>}``,
    where ``<key>`` is the value of the referenced model's ``id`` field.
    Model names are looked up in ``allowed_types`` when a reference is
    resolved or built from a string.
    """

    accessor_class = GForeignKeyAccessor

    def __init__(self, allowed_types=None, *args, **kwargs):
        """Create the field.

        Args:
            allowed_types: iterable of model classes this field may
                reference. They are indexed by class name.

        Any remaining arguments are forwarded to ``BinaryJSONField``.
        """
        super(GForeignKeyField, self).__init__(*args, **kwargs)
        self.allowed_types = {m.__name__: m for m in allowed_types or []}

    def get_ref_field(self, model):
        """Return the ``id`` field of ``model``.

        Args:
            model: a model class, or the name of one of ``allowed_types``.

        Raises:
            TypeError: ``model`` is a name not present in ``allowed_types``.
        """
        if isinstance(model, six.string_types):
            if model not in self.allowed_types:
                raise TypeError('{} not in {}'.format(model, self.allowed_types))
            model = self.allowed_types[model]
        return getattr(model, 'id')

    def db_value(self, value):
        """Convert ``value`` to the JSON representation stored in the db.

        Accepted inputs:

        * a model instance, which becomes ``{'model': ..., 'pk': ...}``;
        * a ``'Model/pk'`` string, whose pk is coerced by the referenced
          field (only the first ``/`` separates model name from pk);
        * anything else (e.g. a reference dict or ``None``), which is passed
          through unchanged.

        Raises:
            ValueError: a string value contains no ``/`` separator.
            TypeError: a string value names a model not in ``allowed_types``.
        """
        if isinstance(value, Model):
            model = value._meta.model
            ref_field = self.get_ref_field(model)
            value = {
                'model': model.__name__,
                'pk': getattr(value, ref_field.name),
            }
        elif isinstance(value, six.string_types):
            model_name, sep, pk = value.partition('/')
            if not sep:
                raise ValueError(
                    'expected a "Model/pk" reference, got {!r}'.format(value))
            ref_field = self.get_ref_field(model_name)
            value = {
                'model': model_name,
                'pk': ref_field.db_value(pk),
            }

        return super(GForeignKeyField, self).db_value(value)

    def __eq__(self, rhs):
        """Build a query expression matching rows that reference ``rhs``.

        ``None`` maps to an IS NULL check, because ``db_value(None)`` has no
        JSON payload to test containment against.
        """
        if rhs is None:
            return self.is_null()
        return self.contains(self.db_value(rhs).adapted)

    def __ne__(self, rhs):
        """Negation of ``__eq__``; ``None`` maps to an IS NOT NULL check."""
        if rhs is None:
            return self.is_null(False)
        return ~(self.__eq__(rhs))

    # On Python 3, a class that defines __eq__ without __hash__ gets
    # __hash__ = None. Restore the inherited hash so field objects stay
    # hashable (e.g. usable as dict keys).
    __hash__ = BinaryJSONField.__hash__