Skip to content

Instantly share code, notes, and snippets.

@nazarewk
Created February 16, 2017 14:12
Show Gist options
  • Save nazarewk/6bcf8eefc03e7bc9f7fbd781a500aad4 to your computer and use it in GitHub Desktop.
Save nazarewk/6bcf8eefc03e7bc9f7fbd781a500aad4 to your computer and use it in GitHub Desktop.
Integrating elasticsearch_dsl with Django
import functools
from django.db import models
from django.db.models import query
from django.utils.functional import empty
from elasticsearch_dsl import Search
class BaseSearchIterable(query.BaseIterable):
    """Base iterable that narrows a queryset to the hits of its attached search.

    If the queryset carries an ``elasticsearch_dsl.Search`` in its ``_search``
    attribute, the search is executed when the iterable is created and the
    rows are restricted to the primary keys of the returned hits. This
    happens even when there are no hits, so an empty search yields no rows.

    ``self.scores`` maps each hit's primary key, converted to the model's pk
    type, to its relevance score. It is left empty when any hit has no score
    (e.g. the search is sorted on a field). In that case subclasses keep the
    database order.
    """

    def __init__(self, queryset, *args, **kwargs):
        # Extra positional/keyword arguments (e.g. chunked_fetch / chunk_size,
        # depending on the Django version) are forwarded to BaseIterable.
        self.search = getattr(queryset, '_search', None)
        self.scores = {}
        if isinstance(self.search, Search):
            pk_field = queryset.model._meta.pk
            # TODO: use only _id and _score
            # https://www.elastic.co/guide/en/elasticsearch/reference/5.2/search-request-stored-fields.html
            # Elasticsearch returns document ids as strings; convert them to the
            # pk's Python type so subclasses can look scores up by obj.pk.
            hits = [
                (pk_field.to_python(hit.meta.id), getattr(hit.meta, 'score', None))
                for hit in self.search.source(())
            ]
            if all(score is not None for _, score in hits):
                self.scores = dict(hits)
            # Always filter: skipping it (no hits, or unscored hits) would
            # return every row of the table instead of the search results.
            queryset = queryset.filter(pk__in=[pk for pk, _ in hits])
        super(BaseSearchIterable, self).__init__(queryset, *args, **kwargs)
class SearchIterable(query.ModelIterable, BaseSearchIterable):
    """Yield model instances, most relevant search hit first.

    When relevance scores are available, instances are re-ordered by
    descending score. Instances with no recorded score get the default key
    -1 and so sort after scored ones. Without scores the database order is
    kept.
    """

    def __iter__(self):
        iterable = super(SearchIterable, self).__iter__()
        if self.scores:
            # A higher _score means a more relevant hit, so sort descending.
            iterable = iter(sorted(
                iterable,
                key=lambda obj: self.scores.get(obj.pk, -1),
                reverse=True,
            ))
        return iterable
class ValuesSearchIterable(query.ValuesIterable, BaseSearchIterable):
    """Yield ``values()`` dicts, most relevant search hit first.

    Each row must contain a ``'pk'`` key (``SearchQuerySet.values`` requires
    ``'pk'`` among the fields when a search is attached). Rows are sorted by
    descending score. Rows with no recorded score get the default key -1.
    """

    def __iter__(self):
        iterable = super(ValuesSearchIterable, self).__iter__()
        if self.scores:
            # A higher _score means a more relevant hit, so sort descending.
            iterable = iter(sorted(
                iterable,
                key=lambda row: self.scores.get(row['pk'], -1),
                reverse=True,
            ))
        return iterable
class ValuesListSearchIterable(query.ValuesListIterable, BaseSearchIterable):
    """Yield ``values_list()`` tuples, most relevant search hit first.

    The primary-key column is located by the position of ``'pk'`` in the
    queryset's requested field names (``SearchQuerySet.values_list`` requires
    ``'pk'`` among them when a search is attached). Rows are sorted by
    descending score. Rows with no recorded score get the default key -1.
    """

    def __iter__(self):
        iterable = super(ValuesListSearchIterable, self).__iter__()
        if self.scores:
            # NOTE(review): assumes Django emits the requested fields in
            # ``_fields`` order, ahead of any annotation columns, so the index of
            # 'pk' in _fields is its column index; confirm for the Django
            # version in use.
            pk_index = list(self.queryset._fields).index('pk')
            # A higher _score means a more relevant hit, so sort descending.
            iterable = iter(sorted(
                iterable,
                key=lambda row: self.scores.get(row[pk_index], -1),
                reverse=True,
            ))
        return iterable
class FlatValuesListSearchIterable(query.FlatValuesListIterable, BaseSearchIterable):
    """Yield flat ``values_list('pk', flat=True)`` values, most relevant first.

    With ``flat=True`` there is a single field, and on a searched queryset it
    must be ``'pk'``, so each yielded value is itself the primary key. Values
    are sorted by descending score. Values with no recorded score get the
    default key -1.
    """

    def __iter__(self):
        iterable = super(FlatValuesListSearchIterable, self).__iter__()
        if self.scores:
            # A higher _score means a more relevant hit, so sort descending.
            iterable = iter(sorted(
                iterable,
                key=lambda pk: self.scores.get(pk, -1),
                reverse=True,
            ))
        return iterable
class SearchQuerySet(models.QuerySet):
    """QuerySet that can be narrowed and ordered by an elasticsearch_dsl ``Search``.

    Attach a search with :meth:`search` or :meth:`es`. Evaluating the
    queryset then runs the search and restricts rows to the hits' primary
    keys (via ``SearchIterable`` and the other ``*SearchIterable`` classes).

    Attribute forwarding (only for names normal lookup does not find, because
    ``__getattr__`` runs only after ordinary lookup fails):

    * ``es_<name>``: attribute ``<name>`` of the attached Search. Callables
      that return a new Search yield a cloned ``SearchQuerySet`` bound to it.
    * ``qs_<name>``: ordinary attribute lookup of ``<name>`` on this queryset.
    * In "es mode" (entered with :meth:`es`, left with :meth:`qs`), any other
      unknown attribute is forwarded to the Search as with ``es_``.
    """

    def __init__(self, *args, **kwargs):
        super(SearchQuerySet, self).__init__(*args, **kwargs)
        self._iterable_class = SearchIterable
        self._search = None
        self.es_mode = False

    def search(self, search, **kwargs):
        """Return a clone bound to ``search`` (an elasticsearch_dsl ``Search``)."""
        assert isinstance(search, Search), 'search() expects an elasticsearch_dsl Search'
        return self._clone(_search=search, **kwargs)

    def es(self, search=empty, **kwargs):
        """Enter Search attributes namespace"""
        kwargs['es_mode'] = True
        if search is not empty:
            return self.search(search, **kwargs)
        return self._clone(**kwargs)

    def qs(self, **kwargs):
        """Exit Search attributes namespace"""
        # Accept the legacy ``search=`` keyword but store it under ``_search``.
        # Storing it as ``search`` would shadow the search() method on the clone.
        kwargs.setdefault('_search', kwargs.pop('search', self._search))
        return self._clone(es_mode=False, **kwargs)

    def _clone(self, **kwargs):
        """Clone the queryset, carrying over the search state.

        Extra keyword arguments are applied to the clone's ``__dict__``. They
        are applied here rather than passed to ``QuerySet._clone``, which
        takes no keyword arguments in newer Django versions.
        """
        for attr in ('_search', 'es_mode'):
            kwargs.setdefault(attr, getattr(self, attr))
        clone = super(SearchQuerySet, self)._clone()
        clone.__dict__.update(kwargs)
        return clone

    def _search_attr(self, name):
        """Return attribute ``name`` of the attached Search.

        Callables are wrapped so a returned Search becomes a cloned queryset.
        """
        attr = getattr(self._search, name)
        if not callable(attr):
            return attr

        @functools.wraps(attr)
        def call(*args, **kwargs):
            ret = attr(*args, **kwargs)
            if isinstance(ret, type(self._search)):
                return self._clone(_search=ret)
            return ret
        return call

    def __getattr__(self, name):
        # Read state straight from __dict__. This hook can run before __init__
        # has set these attributes, and ``self._search`` would then recurse
        # back into __getattr__.
        state = self.__dict__
        if name.startswith('qs_'):
            return object.__getattribute__(self, name[3:])
        if state.get('_search') is not None:
            if name.startswith('es_'):
                return self._search_attr(name[3:])
            if state.get('es_mode', False):
                return self._search_attr(name)
        raise AttributeError('%r object has no attribute %r' % (type(self).__name__, name))

    def __getitem__(self, item):
        """Index or slice the search when one is attached.

        With a search, a slice returns a list of results for that window of
        hits and an integer index returns a single result (``IndexError`` if
        there is none). Without a search, Django's normal behaviour applies.
        """
        if self._search is None:
            return super(SearchQuerySet, self).__getitem__(item)
        # Slice a clone: mutating self would leave any already-cached results
        # stale and change the queryset under the caller's feet.
        results = list(self._clone(_search=self._search[item]))
        if isinstance(item, slice):
            return results
        if not results:
            raise IndexError('SearchQuerySet index out of range')
        return results[0]

    def iterator(self, *args, **kwargs):
        # Forward arguments (e.g. chunk_size) accepted by newer Django versions.
        return super(SearchQuerySet, self).iterator(*args, **kwargs)

    #
    # Django overrides
    #

    def only(self, *fields):
        # Search results are matched to rows by pk, so it must stay loaded.
        assert self._search is None or 'pk' in fields, \
            "only() on a searched queryset must include 'pk'"
        return super(SearchQuerySet, self).only(*fields)

    def defer(self, *fields):
        # Search results are matched to rows by pk, so it must not be deferred.
        assert self._search is None or 'pk' not in fields, \
            "defer() on a searched queryset must not include 'pk'"
        return super(SearchQuerySet, self).defer(*fields)

    def values(self, *fields, **expressions):
        # Delegate to Django (which validates and builds the clone), then swap
        # in the score-aware iterable when a search is attached.
        clone = super(SearchQuerySet, self).values(*fields, **expressions)
        if self._search is not None:
            assert 'pk' in fields, "values() on a searched queryset must include 'pk'"
            clone._iterable_class = ValuesSearchIterable
        return clone

    def values_list(self, *fields, **kwargs):
        flat = kwargs.get('flat', False)
        # Django validates the arguments (e.g. rejects flat with several fields).
        clone = super(SearchQuerySet, self).values_list(*fields, **kwargs)
        if self._search is not None:
            assert 'pk' in fields, "values_list() on a searched queryset must include 'pk'"
            if kwargs.get('named'):
                raise NotImplementedError(
                    'values_list(named=True) is not supported on a searched queryset')
            clone._iterable_class = FlatValuesListSearchIterable if flat else ValuesListSearchIterable
        return clone
# Keep Django's own initializer so the replacement can delegate to it.
_original_prefetch_init = models.Prefetch.__init__


def prefetch_init(self, lookup, queryset=None, to_attr=None):
    """Replacement ``Prefetch.__init__`` that accepts ``ModelIterable`` subclasses.

    NOTE(review): presumably needed because the targeted Django version's
    ``Prefetch`` rejects any ``_iterable_class`` other than ``ModelIterable``
    itself, which would refuse ``SearchQuerySet`` (``SearchIterable``). Confirm
    for the Django version in use.

    The queryset is validated with ``issubclass``. Everything else
    (``prefetch_through``, ``prefetch_to`` / ``to_attr`` handling) is left to
    Django's original initializer.

    Raises:
        ValueError: if ``queryset`` does not yield model instances (e.g. it
            was built with ``values()``), or has no ``_iterable_class``.
    """
    if queryset is not None:
        iterable_class = getattr(queryset, '_iterable_class', None)
        if not (isinstance(iterable_class, type)
                and issubclass(iterable_class, query.ModelIterable)):
            raise ValueError('Prefetch querysets cannot use values().')
    # Run Django's initializer without the queryset so its stricter check is
    # skipped, then attach the already-validated queryset.
    _original_prefetch_init(self, lookup, queryset=None, to_attr=to_attr)
    self.queryset = queryset


models.Prefetch.__init__ = prefetch_init
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment