Created
February 16, 2017 14:12
-
-
Save nazarewk/6bcf8eefc03e7bc9f7fbd781a500aad4 to your computer and use it in GitHub Desktop.
Integrating elasticsearch_dsl with django
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
from django.db import models | |
from django.db.models import query | |
from django.utils.functional import empty | |
from elasticsearch_dsl import Search | |
class BaseSearchIterable(query.BaseIterable):
    """Base iterable that narrows a queryset to the hits of an ES search.

    When the queryset carries an elasticsearch_dsl ``Search`` in its
    ``_search`` attribute, the search is executed once at construction time
    and the hit ids/scores are stored in ``self.scores`` (``{pk: score}``).
    The queryset is then restricted to those primary keys; subclasses use
    ``self.scores`` to order the rows they yield.
    """

    def __init__(self, queryset, *args, **kwargs):
        """Collect hit scores and filter ``queryset`` down to the hit ids.

        Extra positional/keyword arguments are forwarded untouched to
        ``BaseIterable.__init__`` so the class keeps working on Django
        versions that pass additional options to iterables.
        """
        self.search = getattr(queryset, '_search', None)
        self.scores = {}
        if isinstance(self.search, Search):
            model = getattr(queryset, 'model', None)
            pk_field = model._meta.pk if model is not None else None
            scores = self._collect_scores(self.search, pk_field)
            # ``None`` means "hits carry no score": keep the original
            # behaviour of leaving the queryset untouched in that case.
            # NOTE(review): that also means such searches do not filter the
            # queryset at all - confirm this is intended.
            if scores is not None:
                self.scores = scores
                # Filter even when there are no hits, so that an empty
                # search result yields an empty queryset instead of every
                # row in the table.
                queryset = queryset.filter(pk__in=list(scores))
        super(BaseSearchIterable, self).__init__(queryset, *args, **kwargs)

    @staticmethod
    def _collect_scores(search, pk_field=None):
        """Run ``search`` and return ``{pk: score}``, or ``None`` if unscored.

        Hit metadata lives under ``hit.meta`` (``meta.id`` / ``meta.score``)
        in elasticsearch_dsl.  Ids come back from Elasticsearch as strings,
        so they are converted with the model's primary-key field to make
        later lookups by ``obj.pk`` succeed.
        """
        # Local import: keeps this block self-contained without touching the
        # module-level import section.
        from django.core.exceptions import ValidationError

        scores = {}
        # TODO: use only _id and _score
        # https://www.elastic.co/guide/en/elasticsearch/reference/5.2/search-request-stored-fields.html
        for hit in search.source(()):
            meta = hit.meta
            if meta.score is None:
                return None
            key = meta.id
            if pk_field is not None:
                try:
                    key = pk_field.to_python(key)
                except ValidationError:
                    # Leave ids the pk field cannot parse as they are.
                    key = meta.id
            scores[key] = meta.score
        return scores
class SearchIterable(query.ModelIterable, BaseSearchIterable):
    """Yield model instances, ordered by search relevance when scored."""

    def __iter__(self):
        """Return an iterator over model instances.

        When hit scores are available the rows are materialised and sorted
        so the highest score comes first; rows without a known score (-1)
        end up last.  Without scores the regular iterator is returned as is.
        """
        iterable = super(SearchIterable, self).__iter__()
        if not self.scores:
            return iterable
        scores = self.scores
        # reverse=True: most relevant first; the sort is stable, so rows
        # with equal scores keep their database order.
        return iter(sorted(
            iterable,
            key=lambda obj: scores.get(obj.pk, -1),
            reverse=True,
        ))
class ValuesSearchIterable(query.ValuesIterable, BaseSearchIterable):
    """Yield ``values()`` dictionaries, ordered by search relevance."""

    def __iter__(self):
        """Return an iterator over row dictionaries.

        Each row must contain a ``'pk'`` key, which is used to look up the
        hit score.  With scores available the rows are sorted so the highest
        score comes first and rows without a known score (-1) come last.
        """
        iterable = super(ValuesSearchIterable, self).__iter__()
        if not self.scores:
            return iterable
        scores = self.scores
        # reverse=True: most relevant first; stable for equal scores.
        return iter(sorted(
            iterable,
            key=lambda row: scores.get(row['pk'], -1),
            reverse=True,
        ))
class ValuesListSearchIterable(query.ValuesListIterable, BaseSearchIterable):
    """Yield ``values_list()`` tuples, ordered by search relevance."""

    def __iter__(self):
        """Return an iterator over row tuples.

        The position of the primary key inside each tuple is taken from the
        position of ``'pk'`` in the queryset's ``_fields``.
        NOTE(review): assumes the yielded tuples follow the order of
        ``queryset._fields`` - confirm against the Django version in use.

        Raises:
            NotImplementedError: scores are available but ``'pk'`` cannot be
                located among the requested fields.
        """
        iterable = super(ValuesListSearchIterable, self).__iter__()
        if not self.scores:
            return iterable
        fields = list(getattr(self.queryset, '_fields', None) or ())
        if 'pk' not in fields:
            raise NotImplementedError('TODO: retrieve proper pk field')
        pk_index = fields.index('pk')
        scores = self.scores
        # reverse=True: most relevant first; rows with an unknown score (-1)
        # go last.  The sort is stable for equal scores.
        return iter(sorted(
            iterable,
            key=lambda row: scores.get(row[pk_index], -1),
            reverse=True,
        ))
class FlatValuesListSearchIterable(query.FlatValuesListIterable, BaseSearchIterable):
    """Yield flat ``values_list()`` values, ordered by search relevance."""

    def __iter__(self):
        """Return an iterator over single values.

        Each yielded value is used directly as the key into ``self.scores``,
        i.e. it is treated as the primary key.  With scores available the
        values are sorted so the highest score comes first and values
        without a known score (-1) come last.
        """
        iterable = super(FlatValuesListSearchIterable, self).__iter__()
        if not self.scores:
            return iterable
        scores = self.scores
        # reverse=True: most relevant first; stable for equal scores.
        return iter(sorted(
            iterable,
            key=lambda value: scores.get(value, -1),
            reverse=True,
        ))
class SearchQuerySet(models.QuerySet):
    """QuerySet that can carry an elasticsearch_dsl ``Search``.

    The attached search lives in ``_search``.  Attributes that the queryset
    itself does not define can be forwarded to the search:

    * ``es_<name>`` always resolves ``<name>`` on the search;
    * ``qs_<name>`` always resolves ``<name>`` on the queryset;
    * while ``es_mode`` is true (see ``es()`` / ``qs()``), any name unknown
      to the queryset is resolved on the search.

    Search methods that return a new ``Search`` are wrapped so that they
    return a clone of this queryset carrying the new search.
    """

    def __init__(self, *args, **kwargs):
        super(SearchQuerySet, self).__init__(*args, **kwargs)
        self._iterable_class = SearchIterable
        self._search = None
        self.es_mode = False

    def search(self, search, **kwargs):
        """Return a clone with ``search`` attached.

        Extra keyword arguments are set as attributes on the clone.
        """
        assert isinstance(search, Search)
        return self._clone(_search=search, **kwargs)

    def es(self, search=empty, **kwargs):
        """Enter Search attributes namespace"""
        kwargs['es_mode'] = True
        if search is not empty:
            return self.search(search, **kwargs)
        return self._clone(**kwargs)

    def qs(self, **kwargs):
        """Exit Search attributes namespace"""
        # The search is stored under ``_search``.  Writing it to ``search``
        # would shadow the ``search()`` method on the clone, so a ``search``
        # keyword is translated to the real attribute name.  When it is not
        # given, ``_clone`` carries the current search over by itself.
        if 'search' in kwargs:
            kwargs.setdefault('_search', kwargs.pop('search'))
        return self._clone(es_mode=False, **kwargs)

    def _clone(self, **kwargs):
        """Clone the queryset, carrying ``_search`` and ``es_mode`` along.

        Keyword arguments become attributes of the clone.  They are applied
        here rather than passed to ``QuerySet._clone`` so this also works on
        Django versions whose ``_clone`` accepts no keyword arguments.
        """
        for attr in ('_search', 'es_mode'):
            kwargs.setdefault(attr, getattr(self, attr))
        clone = super(SearchQuerySet, self)._clone()
        clone.__dict__.update(kwargs)
        return clone

    def _search_attr(self, name):
        """Resolve ``name`` on the attached search.

        Callables are wrapped: when the call returns a new search object the
        wrapper returns a clone of this queryset bound to it, otherwise the
        raw return value is passed through.
        """
        attr = getattr(self._search, name)
        if not callable(attr):
            return attr

        @functools.wraps(attr)
        def call(*args, **kwargs):
            ret = attr(*args, **kwargs)
            if isinstance(ret, self._search.__class__):
                # Must be ``_search``: that is the attribute read everywhere
                # else, and ``search`` would shadow the method of that name.
                return self._clone(_search=ret)
            return ret

        return call

    def __getattr__(self, name):
        """Fallback lookup for attributes the queryset does not define.

        Python only calls this after normal attribute lookup has failed, so
        real queryset attributes always win over search attributes.

        Raises:
            AttributeError: when the name cannot be resolved.
        """
        # Read through ``__dict__`` so that a lookup on a not yet initialised
        # instance cannot recurse back into ``__getattr__``.
        state = self.__dict__
        if state.get('_search') is not None:
            if name.startswith('qs_'):
                # Plain lookup of the un-prefixed name, bypassing the search
                # fallback below.
                return object.__getattribute__(self, name[3:])
            if name.startswith('es_'):
                return self._search_attr(name[3:])
            if state.get('es_mode', False):
                return self._search_attr(name)
        raise AttributeError(
            '%r object has no attribute %r' % (type(self).__name__, name)
        )

    def __getitem__(self, item):
        """Index/slice the queryset.

        Without an attached search this is regular queryset behaviour.  With
        a search, ``item`` is applied to the search and the matching rows
        are returned as a list (also for an integer index, as before).
        """
        if self._search is None:
            return super(SearchQuerySet, self).__getitem__(item)
        # Work on a clone: slicing must neither narrow this queryset for
        # later use nor return its already populated result cache.
        return list(self._clone(_search=self._search[item]))

    def iterator(self, *args, **kwargs):
        """Delegate to ``QuerySet.iterator``, forwarding any arguments."""
        return super(SearchQuerySet, self).iterator(*args, **kwargs)

    #
    # Django overrides
    #
    def only(self, *fields):
        """``QuerySet.only`` that requires ``'pk'`` while a search is attached."""
        assert self._search is None or 'pk' in fields
        return super(SearchQuerySet, self).only(*fields)

    def defer(self, *fields):
        """``QuerySet.defer`` that forbids deferring ``'pk'`` while a search is attached."""
        assert self._search is None or 'pk' not in fields
        return super(SearchQuerySet, self).defer(*fields)

    def values(self, *fields):
        """``QuerySet.values`` using the search-aware iterable when needed.

        While a search is attached, ``'pk'`` must be among ``fields`` because
        the iterable uses it to look up hit scores.
        """
        clone = self._values(*fields)
        if self._search is not None:
            assert 'pk' in fields
            clone._iterable_class = ValuesSearchIterable
        else:
            clone._iterable_class = query.ValuesIterable
        return clone

    def values_list(self, *fields, **kwargs):
        """``QuerySet.values_list`` using search-aware iterables when needed.

        Accepts the ``flat`` keyword only.  While a search is attached,
        ``'pk'`` must be among ``fields``.

        Raises:
            TypeError: on unexpected keyword arguments, or when ``flat`` is
                combined with more than one field.
        """
        flat = kwargs.pop('flat', False)
        if kwargs:
            raise TypeError('Unexpected keyword arguments to values_list: %s' % (list(kwargs),))
        if flat and len(fields) > 1:
            raise TypeError("'flat' is not valid when values_list is called with more than one field.")
        clone = self._values(*fields)
        if self._search is not None:
            assert 'pk' in fields
            clone._iterable_class = FlatValuesListSearchIterable if flat else ValuesListSearchIterable
        else:
            clone._iterable_class = query.FlatValuesListIterable if flat else query.ValuesListIterable
        return clone
def prefetch_init(self, lookup, queryset=None, to_attr=None):
    """Initialise a ``Prefetch`` object; installed as ``Prefetch.__init__``.

    A custom ``queryset`` is accepted as long as its ``_iterable_class`` is
    ``query.ModelIterable`` or a subclass of it.
    NOTE(review): presumably this exists so that querysets using
    ``SearchIterable`` can be prefetched - confirm against the Django
    version in use.

    Raises:
        ValueError: ``queryset`` does not use a ``ModelIterable`` subclass.
    """
    # Path traversed to perform the prefetch.
    self.prefetch_through = lookup
    # Path to the attribute that stores the result; replaced further down
    # when ``to_attr`` is given.
    self.prefetch_to = lookup
    if queryset is not None:
        if not issubclass(queryset._iterable_class, query.ModelIterable):
            raise ValueError('Prefetch querysets cannot use values().')
    if to_attr:
        # Swap the last component of the lookup path for ``to_attr``.
        separator = query.LOOKUP_SEP
        parent_path = lookup.split(separator)[:-1]
        self.prefetch_to = separator.join(parent_path + [to_attr])
    self.queryset = queryset
    self.to_attr = to_attr


# Monkey-patch: every ``Prefetch`` created after this module is imported
# uses the constructor above.
models.Prefetch.__init__ = prefetch_init
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment