Skip to content

Instantly share code, notes, and snippets.

@kimchy
Created December 27, 2011 14:08
Show Gist options
  • Save kimchy/1523762 to your computer and use it in GitHub Desktop.
Save kimchy/1523762 to your computer and use it in GitHub Desktop.
Index: src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java (date 1324946452000)
+++ src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java (revision )
@@ -21,7 +21,6 @@
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.Query;
-import org.elasticsearch.ElasticSearchIllegalStateException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.lucene.search.function.BoostScoreFunction;
@@ -96,12 +95,10 @@
throw new QueryParsingException(parseContext.index(), "[custom_filters_score] requires 'query' field");
}
+ // we only need the current search context for scripting; it's ok not to have it (null) when only boosts are used
SearchContext context = SearchContext.current();
- if (context == null) {
- throw new ElasticSearchIllegalStateException("No search context on going...");
- }
- FiltersFunctionScoreQuery.FilterScoreGroup filterGroups = buildFunctionGroups(context, scriptLang, vars, rootGroup);
+ FiltersFunctionScoreQuery.FilterScoreGroup filterGroups = rootGroup.buildFilterItem(context, scriptLang, vars);
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterGroups);
functionScoreQuery.setBoost(boost);
@@ -148,7 +145,6 @@
parseFilterLevel(parseContext, parser, subGroup);
currentGroup.add(subGroup);
possibleNonNestedObjectStarted = false;
- continue;
} else if (token == XContentParser.Token.START_OBJECT) {
possibleNonNestedObjectStarted = true;
if ("filter".equals(currentFieldName)) {
@@ -180,45 +176,36 @@
}
}
- private FiltersFunctionScoreQuery.FilterScoreGroup buildFunctionGroups(final SearchContext context, final String scriptLang, final Map<String, Object> vars, final NestedGroup currentNesting) {
- ArrayList<FiltersFunctionScoreQuery.FilterItem> filterItems = new ArrayList<FiltersFunctionScoreQuery.FilterItem>(currentNesting.size());
- for (NestedItem nestedItem : currentNesting.items) {
- if (nestedItem instanceof NestedGroup) {
- filterItems.add(buildFunctionGroups(context, scriptLang, vars, (NestedGroup) nestedItem));
- } else if (nestedItem instanceof NestedFilter) {
- NestedFilter filter = (NestedFilter) nestedItem;
- ScoreFunction scoreFunction;
- String script = filter.script;
- if (script != null) {
- SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, script, vars);
- scoreFunction = new CustomScoreQueryParser.ScriptScoreFunction(script, vars, searchScript);
- } else {
- scoreFunction = new BoostScoreFunction(filter.boost);
- }
- filterItems.add(new FiltersFunctionScoreQuery.FilterFunction(filter.filter, scoreFunction));
- }
- }
- return new FiltersFunctionScoreQuery.FilterScoreGroup(currentNesting.scoreMode, filterItems.toArray(new FiltersFunctionScoreQuery.FilterItem[filterItems.size()]));
- }
+ abstract class NestedItem<T extends FiltersFunctionScoreQuery.FilterItem> {
-
- private abstract class NestedItem {
-
+ public abstract T buildFilterItem(final SearchContext context, final String scriptLang, final Map<String, Object> vars);
}
- private class NestedFilter extends NestedItem {
+ class NestedFilter extends NestedItem<FiltersFunctionScoreQuery.FilterFunction> {
Filter filter;
String script;
- Float boost;
+ float boost;
- private NestedFilter(Filter filter, String script, Float boost) {
+ NestedFilter(Filter filter, String script, float boost) {
this.filter = filter;
this.script = script;
this.boost = boost;
}
+
+ @Override
+ public FiltersFunctionScoreQuery.FilterFunction buildFilterItem(SearchContext context, String scriptLang, Map<String, Object> vars) {
+ ScoreFunction scoreFunction;
+ if (script != null) {
+ SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, script, vars);
+ scoreFunction = new CustomScoreQueryParser.ScriptScoreFunction(script, vars, searchScript);
+ } else {
+ scoreFunction = new BoostScoreFunction(boost);
- }
+ }
+ return new FiltersFunctionScoreQuery.FilterFunction(filter, scoreFunction);
+ }
+ }
- private class NestedGroup extends NestedItem {
+ class NestedGroup extends NestedItem<FiltersFunctionScoreQuery.FilterScoreGroup> {
NestedGroup parent;
FiltersFunctionScoreQuery.ScoreMode scoreMode;
ArrayList<NestedItem> items = new ArrayList<NestedItem>();
@@ -241,5 +228,13 @@
return items.size();
}
+ @Override
+ public FiltersFunctionScoreQuery.FilterScoreGroup buildFilterItem(SearchContext context, String scriptLang, Map<String, Object> vars) {
+ FiltersFunctionScoreQuery.FilterItem[] filterItems = new FiltersFunctionScoreQuery.FilterItem[items.size()];
+ for (int i = 0; i < items.size(); i++) {
+ filterItems[i] = items.get(i).buildFilterItem(context, scriptLang, vars);
+ }
+ return new FiltersFunctionScoreQuery.FilterScoreGroup(scoreMode, filterItems);
+ }
}
}
\ No newline at end of file
Index: src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java (date 1324946452000)
+++ src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java (revision )
@@ -125,6 +125,12 @@
}
@Override
+ public float factor(int docId) {
+ script.setNextDocId(docId);
+ return script.runAsFloat();
+ }
+
+ @Override
public Explanation explain(int docId, Explanation subQueryExpl) {
float score = score(docId, subQueryExpl.getValue());
Explanation exp = new Explanation(score, "script score function: product of:");
\ No newline at end of file
Index: src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java (date 1324946452000)
+++ src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java (revision )
@@ -31,9 +31,7 @@
import java.util.Arrays;
-import static org.elasticsearch.client.Requests.indexRequest;
-import static org.elasticsearch.client.Requests.refreshRequest;
-import static org.elasticsearch.client.Requests.searchRequest;
+import static org.elasticsearch.client.Requests.*;
import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.index.query.FilterBuilders.termFilter;
@@ -172,8 +170,8 @@
SearchResponse searchResponse = client.prepareSearch("test")
.setQuery(customFiltersScoreQuery(allBoosted2Query)
- .add(termFilter("field", "value4"), "_score * 2")
- .add(termFilter("field", "value2"), "_score * 3"))
+ .add(termFilter("field", "value4"), "2")
+ .add(termFilter("field", "value2"), "3"))
.setExplain(true)
.execute().actionGet();
@@ -286,7 +284,7 @@
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
assertThat(searchResponse.hits().totalHits(), equalTo(4l));
assertThat(searchResponse.hits().getAt(0).id(), equalTo("1"));
- assertThat(searchResponse.hits().getAt(0).score(), equalTo(60.0f));
+ assertThat(searchResponse.hits().getAt(0).score(), equalTo(30.0f));
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation());
assertThat(searchResponse.hits().getAt(1).id(), equalTo("3"));
assertThat(searchResponse.hits().getAt(1).score(), equalTo(10.0f));
\ No newline at end of file
Index: src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java (date 1324946452000)
+++ src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java (revision )
@@ -31,5 +31,7 @@
float score(int docId, float subQueryScore);
+ float factor(int docId);
+
Explanation explain(int docId, Explanation subQueryExpl);
}
Index: src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java (date 1324946452000)
+++ src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java (revision )
@@ -27,7 +27,9 @@
import org.elasticsearch.common.lucene.docset.DocSets;
import java.io.IOException;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Set;
/**
* A query that allows for a pluggable boost function / filter. If it matches the filter, it will
@@ -36,6 +38,13 @@
public class FiltersFunctionScoreQuery extends Query {
public static abstract class FilterItem {
+
+ public abstract void setNextReader(IndexReader reader) throws IOException;
+
+ /**
+ * Returns {@link Float#NaN} if no scoring / matching happened.
+ */
+ public abstract float factor(int docId);
}
public static class FilterScoreGroup extends FilterItem {
@@ -48,6 +57,71 @@
}
@Override
+ public void setNextReader(IndexReader reader) throws IOException {
+ for (FilterItem filterItem : filterItems) {
+ filterItem.setNextReader(reader);
+ }
+ }
+
+ @Override
+ public float factor(int docId) {
+ switch (scoreMode) {
+ case First:
+ for (FilterItem filterItem : filterItems) {
+ float factor = filterItem.factor(docId);
+ if (!Float.isNaN(factor)) {
+ return factor;
+ }
+ }
+ return Float.NaN;
+ case Max:
+ float maxFactor = Float.NEGATIVE_INFINITY;
+ for (FilterItem filterItem : filterItems) {
+ float factor = filterItem.factor(docId);
+ if (!Float.isNaN(factor)) {
+ maxFactor = Math.max(maxFactor, factor);
+ }
+ }
+ return maxFactor == Float.NEGATIVE_INFINITY ? Float.NaN : maxFactor;
+ case Min:
+ float minFactor = Float.POSITIVE_INFINITY;
+ for (FilterItem filterItem : filterItems) {
+ float factor = filterItem.factor(docId);
+ if (!Float.isNaN(factor)) {
+ minFactor = Math.min(minFactor, factor);
+ }
+ }
+ return minFactor == Float.POSITIVE_INFINITY ? Float.NaN : minFactor;
+ case Avg:
+ case Total:
+ case Multiply:
+ float totalFactor = 0;
+ float multFactor = 1;
+ int count = 0;
+ for (FilterItem filterItem : filterItems) {
+ float factor = filterItem.factor(docId);
+ if (!Float.isNaN(factor)) {
+ count++;
+ totalFactor += factor;
+ multFactor *= factor;
+ }
+ }
+ if (count == 0) {
+ return Float.NaN;
+ }
+ if (scoreMode == ScoreMode.Total) {
+ return totalFactor;
+ }
+ if (scoreMode == ScoreMode.Multiply) {
+ return multFactor;
+ }
+ // Avg
+ return totalFactor / count;
+ }
+ return Float.NaN;
+ }
+
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("filter score group (").append(scoreMode.toString()).append(", children: [");
@@ -77,6 +151,7 @@
public static class FilterFunction extends FilterItem {
public final Filter filter;
public final ScoreFunction function;
+ DocSet docSet;
public FilterFunction(Filter filter, ScoreFunction function) {
this.filter = filter;
@@ -84,6 +159,20 @@
}
@Override
+ public void setNextReader(IndexReader reader) throws IOException {
+ function.setNextReader(reader);
+ docSet = DocSets.convert(reader, filter.getDocIdSet(reader));
+ }
+
+ @Override
+ public float factor(int docId) {
+ if (!docSet.get(docId)) {
+ return Float.NaN;
+ }
+ return function.factor(docId);
+ }
+
+ @Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
@@ -114,7 +203,6 @@
public static enum ScoreMode {First, Avg, Max, Total, Min, Multiply}
Query subQuery;
- Map<FilterFunction, DocSet> docSets = new HashMap<FilterFunction, DocSet>(); // TODO: something better than a map?
FilterScoreGroup filters;
public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions) {
@@ -132,12 +220,6 @@
return subQuery;
}
- // 26-Dec-2011 NOTE: breaking API change from getFilterFunctions() : FilterFunction -- which was unused method in
- // ES code base, but not necessarily unused for users of ES
- public FilterScoreGroup getFilterFunctions() {
- return filters;
- }
-
@Override
public Query rewrite(IndexReader reader) throws IOException {
Query newQ = subQuery.rewrite(reader);
@@ -195,24 +277,10 @@
if (subQueryScorer == null) {
return null;
}
- initFilterFunctionDocSets(reader, filters);
- return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, filters, docSets);
+ filters.setNextReader(reader);
+ return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, filters);
}
- private void initFilterFunctionDocSets(IndexReader reader, FilterScoreGroup group) throws IOException {
- for (FilterItem item : group.filterItems) {
- if (item instanceof FilterScoreGroup) {
- initFilterFunctionDocSets(reader, (FilterScoreGroup) item);
- } else if (item instanceof FilterFunction) {
- FilterFunction filterFunction = (FilterFunction) item;
- if (!docSets.containsKey(filterFunction)) {
- filterFunction.function.setNextReader(reader);
- docSets.put(filterFunction, DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader)));
- }
- }
- }
- }
-
@Override
public Explanation explain(IndexReader reader, int doc) throws IOException {
Explanation subQueryExpl = subQueryWeight.explain(reader, doc);
@@ -344,16 +412,12 @@
private final float subQueryWeight;
private final Scorer scorer;
private final FilterScoreGroup filters;
- private final Map<FilterFunction, DocSet> docSets;
-
- private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer,
- FilterScoreGroup filters, Map<FilterFunction, DocSet> docSets) throws IOException {
+ private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer, FilterScoreGroup filters) throws IOException {
super(similarity);
this.subQueryWeight = w.getValue();
this.scorer = scorer;
this.filters = filters;
- this.docSets = docSets;
}
@Override
@@ -373,101 +437,12 @@
@Override
public float score() throws IOException {
- int docId = scorer.docID();
- float score = scorer.score();
- score = score(filters, docId, score);
- return subQueryWeight * score;
+ float queryScore = scorer.score();
+ float factor = filters.factor(scorer.docID());
+ if (!Float.isNaN(factor)) {
+ queryScore *= factor;
- }
+ }
-
- private float score(FilterScoreGroup group, int docId, float score) {
- float newScore = scoreOnlyMatchingFilters(group, docId, score);
- return newScore != Float.NEGATIVE_INFINITY ? newScore : score;
- }
-
- private float scoreOnlyMatchingFilters(FilterScoreGroup group, int docId, float subqueryScore) {
- if (group.scoreMode == ScoreMode.First) {
- for (FilterItem item : group.filterItems) {
- if (item instanceof FilterScoreGroup) {
- float tempScore = scoreOnlyMatchingFilters((FilterScoreGroup) item, docId, subqueryScore);
- if (tempScore != Float.NEGATIVE_INFINITY) {
- return tempScore;
- }
- } else {
- FilterFunction filterFunction = (FilterFunction) item;
- DocSet docSet = docSets.get(filterFunction);
- if (docSet != null && docSet.get(docId)) {
- return filterFunction.function.score(docId, subqueryScore);
- }
- }
- }
- } else if (group.scoreMode == ScoreMode.Max) {
- float maxScore = Float.NEGATIVE_INFINITY;
- for (FilterItem item : group.filterItems) {
- if (item instanceof FilterScoreGroup) {
- maxScore = Math.max(scoreOnlyMatchingFilters((FilterScoreGroup) item, docId, subqueryScore), maxScore);
- } else {
- FilterFunction filterFunction = (FilterFunction) item;
- DocSet docSet = docSets.get(filterFunction);
- if (docSet != null && docSet.get(docId)) {
- maxScore = Math.max(filterFunction.function.score(docId, subqueryScore), maxScore);
- }
- }
- }
- if (maxScore != Float.NEGATIVE_INFINITY) {
- return maxScore;
- }
- } else if (group.scoreMode == ScoreMode.Min) {
- float minScore = Float.POSITIVE_INFINITY;
- for (FilterItem item : group.filterItems) {
- if (item instanceof FilterScoreGroup) {
- float tempScore = scoreOnlyMatchingFilters((FilterScoreGroup) item, docId, subqueryScore);
- if (tempScore != Float.NEGATIVE_INFINITY) {
- minScore = Math.min(tempScore, minScore);
- }
- } else {
- FilterFunction filterFunction = (FilterFunction) item;
- DocSet docSet = docSets.get(filterFunction);
- if (docSet != null && docSet.get(docId)) {
- minScore = Math.min(filterFunction.function.score(docId, subqueryScore), minScore);
- }
- }
- }
- if (minScore != Float.POSITIVE_INFINITY) {
- return minScore;
- }
- } else { // Avg / Total / multiply
- float totalScore = 0.0f;
- float multiplicativeScore = 1.0f;
- int count = 0;
- for (FilterItem item : group.filterItems) {
- if (item instanceof FilterScoreGroup) {
- float tempScore = scoreOnlyMatchingFilters((FilterScoreGroup) item, docId, subqueryScore);
- if (tempScore != Float.NEGATIVE_INFINITY) {
- totalScore += tempScore;
- multiplicativeScore *= tempScore;
- count++;
- }
- } else {
- FilterFunction filterFunction = (FilterFunction) item;
- DocSet docSet = docSets.get(filterFunction);
- if (docSet != null && docSet.get(docId)) {
- float tempScore = filterFunction.function.score(docId, subqueryScore);
- totalScore += tempScore;
- multiplicativeScore *= tempScore;
- count++;
- }
- }
- }
- if (count != 0) {
- if (group.scoreMode == ScoreMode.Avg) {
- return totalScore / count;
- } else if (group.scoreMode == ScoreMode.Multiply) {
- return multiplicativeScore;
- }
- return totalScore;
- }
- }
- return Float.NEGATIVE_INFINITY;
+ return subQueryWeight * queryScore;
}
}
Index: src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
--- src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java (date 1324946452000)
+++ src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java (revision )
@@ -49,6 +49,11 @@
}
@Override
+ public float factor(int docId) {
+ return boost;
+ }
+
+ @Override
public Explanation explain(int docId, Explanation subQueryExpl) {
Explanation exp = new Explanation(boost * subQueryExpl.getValue(), "static boost function: product of:");
exp.addDetail(subQueryExpl);
@@ -71,5 +76,10 @@
@Override
public int hashCode() {
return (boost != +0.0f ? Float.floatToIntBits(boost) : 0);
+ }
+
+ @Override
+ public String toString() {
+ return "boost(" + boost + ")";
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment