Created December 27, 2011 14:08
Save kimchy/1523762 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
--- src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java (date 1324946452000) | |
+++ src/main/java/org/elasticsearch/index/query/CustomFiltersScoreQueryParser.java (revision ) | |
@@ -21,7 +21,6 @@ | |
import org.apache.lucene.search.Filter; | |
import org.apache.lucene.search.Query; | |
-import org.elasticsearch.ElasticSearchIllegalStateException; | |
import org.elasticsearch.common.Strings; | |
import org.elasticsearch.common.inject.Inject; | |
import org.elasticsearch.common.lucene.search.function.BoostScoreFunction; | |
@@ -96,12 +95,10 @@ | |
throw new QueryParsingException(parseContext.index(), "[custom_filters_score] requires 'query' field"); | |
} | |
+ // the current search context is only needed by script filters; it may be null when only boosts are used (a script filter would then NPE) | |
SearchContext context = SearchContext.current(); | |
- if (context == null) { | |
- throw new ElasticSearchIllegalStateException("No search context on going..."); | |
- } | |
- FiltersFunctionScoreQuery.FilterScoreGroup filterGroups = buildFunctionGroups(context, scriptLang, vars, rootGroup); | |
+ FiltersFunctionScoreQuery.FilterScoreGroup filterGroups = rootGroup.buildFilterItem(context, scriptLang, vars); | |
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, filterGroups); | |
functionScoreQuery.setBoost(boost); | |
@@ -148,7 +145,6 @@ | |
parseFilterLevel(parseContext, parser, subGroup); | |
currentGroup.add(subGroup); | |
possibleNonNestedObjectStarted = false; | |
- continue; | |
} else if (token == XContentParser.Token.START_OBJECT) { | |
possibleNonNestedObjectStarted = true; | |
if ("filter".equals(currentFieldName)) { | |
@@ -180,45 +176,36 @@ | |
} | |
} | |
- private FiltersFunctionScoreQuery.FilterScoreGroup buildFunctionGroups(final SearchContext context, final String scriptLang, final Map<String, Object> vars, final NestedGroup currentNesting) { | |
- ArrayList<FiltersFunctionScoreQuery.FilterItem> filterItems = new ArrayList<FiltersFunctionScoreQuery.FilterItem>(currentNesting.size()); | |
- for (NestedItem nestedItem : currentNesting.items) { | |
- if (nestedItem instanceof NestedGroup) { | |
- filterItems.add(buildFunctionGroups(context, scriptLang, vars, (NestedGroup) nestedItem)); | |
- } else if (nestedItem instanceof NestedFilter) { | |
- NestedFilter filter = (NestedFilter) nestedItem; | |
- ScoreFunction scoreFunction; | |
- String script = filter.script; | |
- if (script != null) { | |
- SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, script, vars); | |
- scoreFunction = new CustomScoreQueryParser.ScriptScoreFunction(script, vars, searchScript); | |
- } else { | |
- scoreFunction = new BoostScoreFunction(filter.boost); | |
- } | |
- filterItems.add(new FiltersFunctionScoreQuery.FilterFunction(filter.filter, scoreFunction)); | |
- } | |
- } | |
- return new FiltersFunctionScoreQuery.FilterScoreGroup(currentNesting.scoreMode, filterItems.toArray(new FiltersFunctionScoreQuery.FilterItem[filterItems.size()])); | |
- } | |
+ abstract class NestedItem<T extends FiltersFunctionScoreQuery.FilterItem> { | |
- | |
- private abstract class NestedItem { | |
- | |
+ public abstract T buildFilterItem(final SearchContext context, final String scriptLang, final Map<String, Object> vars); | |
} | |
- private class NestedFilter extends NestedItem { | |
+ class NestedFilter extends NestedItem<FiltersFunctionScoreQuery.FilterFunction> { | |
Filter filter; | |
String script; | |
- Float boost; | |
+ float boost; | |
- private NestedFilter(Filter filter, String script, Float boost) { | |
+ NestedFilter(Filter filter, String script, float boost) { | |
this.filter = filter; | |
this.script = script; | |
this.boost = boost; | |
} | |
+ | |
+ @Override | |
+ public FiltersFunctionScoreQuery.FilterFunction buildFilterItem(SearchContext context, String scriptLang, Map<String, Object> vars) { | |
+ ScoreFunction scoreFunction; | |
+ if (script != null) { | |
+ SearchScript searchScript = context.scriptService().search(context.lookup(), scriptLang, script, vars); | |
+ scoreFunction = new CustomScoreQueryParser.ScriptScoreFunction(script, vars, searchScript); | |
+ } else { | |
+ scoreFunction = new BoostScoreFunction(boost); | |
- } | |
+ } | |
+ return new FiltersFunctionScoreQuery.FilterFunction(filter, scoreFunction); | |
+ } | |
+ } | |
- private class NestedGroup extends NestedItem { | |
+ class NestedGroup extends NestedItem<FiltersFunctionScoreQuery.FilterScoreGroup> { | |
NestedGroup parent; | |
FiltersFunctionScoreQuery.ScoreMode scoreMode; | |
ArrayList<NestedItem> items = new ArrayList<NestedItem>(); | |
@@ -241,5 +228,13 @@ | |
return items.size(); | |
} | |
+ @Override | |
+ public FiltersFunctionScoreQuery.FilterScoreGroup buildFilterItem(SearchContext context, String scriptLang, Map<String, Object> vars) { | |
+ FiltersFunctionScoreQuery.FilterItem[] filterItems = new FiltersFunctionScoreQuery.FilterItem[items.size()]; | |
+ for (int i = 0; i < items.size(); i++) { | |
+ filterItems[i] = items.get(i).buildFilterItem(context, scriptLang, vars); | |
+ } | |
+ return new FiltersFunctionScoreQuery.FilterScoreGroup(scoreMode, filterItems); | |
+ } | |
} | |
} | |
\ No newline at end of file | |
Index: src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
--- src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java (date 1324946452000) | |
+++ src/main/java/org/elasticsearch/index/query/CustomScoreQueryParser.java (revision ) | |
@@ -125,6 +125,12 @@ | |
} | |
@Override | |
+ public float factor(int docId) { | |
+ script.setNextDocId(docId); | |
+ return script.runAsFloat(); | |
+ } | |
+ | |
+ @Override | |
public Explanation explain(int docId, Explanation subQueryExpl) { | |
float score = score(docId, subQueryExpl.getValue()); | |
Explanation exp = new Explanation(score, "script score function: product of:"); | |
\ No newline at end of file | |
Index: src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
--- src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java (date 1324946452000) | |
+++ src/test/java/org/elasticsearch/test/integration/search/customscore/CustomScoreSearchTests.java (revision ) | |
@@ -31,9 +31,7 @@ | |
import java.util.Arrays; | |
-import static org.elasticsearch.client.Requests.indexRequest; | |
-import static org.elasticsearch.client.Requests.refreshRequest; | |
-import static org.elasticsearch.client.Requests.searchRequest; | |
+import static org.elasticsearch.client.Requests.*; | |
import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder; | |
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; | |
import static org.elasticsearch.index.query.FilterBuilders.termFilter; | |
@@ -172,8 +170,8 @@ | |
SearchResponse searchResponse = client.prepareSearch("test") | |
.setQuery(customFiltersScoreQuery(allBoosted2Query) | |
- .add(termFilter("field", "value4"), "_score * 2") | |
- .add(termFilter("field", "value2"), "_score * 3")) | |
+ .add(termFilter("field", "value4"), "2") | |
+ .add(termFilter("field", "value2"), "3")) | |
.setExplain(true) | |
.execute().actionGet(); | |
@@ -286,7 +284,7 @@ | |
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0)); | |
assertThat(searchResponse.hits().totalHits(), equalTo(4l)); | |
assertThat(searchResponse.hits().getAt(0).id(), equalTo("1")); | |
- assertThat(searchResponse.hits().getAt(0).score(), equalTo(60.0f)); | |
+ assertThat(searchResponse.hits().getAt(0).score(), equalTo(30.0f)); | |
logger.info("--> Hit[0] {} Explanation {}", searchResponse.hits().getAt(0).id(), searchResponse.hits().getAt(0).explanation()); | |
assertThat(searchResponse.hits().getAt(1).id(), equalTo("3")); | |
assertThat(searchResponse.hits().getAt(1).score(), equalTo(10.0f)); | |
\ No newline at end of file | |
Index: src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
--- src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java (date 1324946452000) | |
+++ src/main/java/org/elasticsearch/common/lucene/search/function/ScoreFunction.java (revision ) | |
@@ -31,5 +31,7 @@ | |
float score(int docId, float subQueryScore); | |
+ float factor(int docId); | |
+ | |
Explanation explain(int docId, Explanation subQueryExpl); | |
} | |
Index: src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
--- src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java (date 1324946452000) | |
+++ src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java (revision ) | |
@@ -27,7 +27,9 @@ | |
import org.elasticsearch.common.lucene.docset.DocSets; | |
import java.io.IOException; | |
-import java.util.*; | |
+import java.util.ArrayList; | |
+import java.util.Arrays; | |
+import java.util.Set; | |
/** | |
* A query that allows for a pluggable boost function / filter. If it matches the filter, it will | |
@@ -36,6 +38,13 @@ | |
public class FiltersFunctionScoreQuery extends Query { | |
public static abstract class FilterItem { | |
+ | |
+ public abstract void setNextReader(IndexReader reader) throws IOException; | |
+ | |
+ /** | |
+ * Returns {@link Float#NaN} if no scoring / matching happened. | |
+ */ | |
+ public abstract float factor(int docId); | |
} | |
public static class FilterScoreGroup extends FilterItem { | |
@@ -48,6 +57,71 @@ | |
} | |
@Override | |
+ public void setNextReader(IndexReader reader) throws IOException { | |
+ for (FilterItem filterItem : filterItems) { | |
+ filterItem.setNextReader(reader); | |
+ } | |
+ } | |
+ | |
+ @Override | |
+ public float factor(int docId) { | |
+ switch (scoreMode) { | |
+ case First: | |
+ for (FilterItem filterItem : filterItems) { | |
+ float factor = filterItem.factor(docId); | |
+ if (!Float.isNaN(factor)) { | |
+ return factor; | |
+ } | |
+ } | |
+ return Float.NaN; | |
+ case Max: | |
+ float maxFactor = Float.NEGATIVE_INFINITY; | |
+ for (FilterItem filterItem : filterItems) { | |
+ float factor = filterItem.factor(docId); | |
+ if (!Float.isNaN(factor)) { | |
+ maxFactor = Math.max(maxFactor, factor); | |
+ } | |
+ } | |
+ return maxFactor == Float.NEGATIVE_INFINITY ? Float.NaN : maxFactor; | |
+ case Min: | |
+ float minFactor = Float.POSITIVE_INFINITY; | |
+ for (FilterItem filterItem : filterItems) { | |
+ float factor = filterItem.factor(docId); | |
+ if (!Float.isNaN(factor)) { | |
+ minFactor = Math.min(minFactor, factor); | |
+ } | |
+ } | |
+ return minFactor == Float.POSITIVE_INFINITY ? Float.NaN : minFactor; | |
+ case Avg: | |
+ case Total: | |
+ case Multiply: | |
+ float totalFactor = 0; | |
+ float multFactor = 1; | |
+ int count = 0; | |
+ for (FilterItem filterItem : filterItems) { | |
+ float factor = filterItem.factor(docId); | |
+ if (!Float.isNaN(factor)) { | |
+ count++; | |
+ totalFactor += factor; | |
+ multFactor *= factor; | |
+ } | |
+ } | |
+ if (count == 0) { | |
+ return Float.NaN; | |
+ } | |
+ if (scoreMode == ScoreMode.Total) { | |
+ return totalFactor; | |
+ } | |
+ if (scoreMode == ScoreMode.Multiply) { | |
+ return multFactor; | |
+ } | |
+ // Avg | |
+ return totalFactor / count; | |
+ } | |
+ return Float.NaN; | |
+ } | |
+ | |
+ @Override | |
public String toString() { | |
StringBuilder sb = new StringBuilder(); | |
sb.append("filter score group (").append(scoreMode.toString()).append(", children: ["); | |
@@ -77,6 +151,7 @@ | |
public static class FilterFunction extends FilterItem { | |
public final Filter filter; | |
public final ScoreFunction function; | |
+ DocSet docSet; | |
public FilterFunction(Filter filter, ScoreFunction function) { | |
this.filter = filter; | |
@@ -84,6 +159,20 @@ | |
} | |
@Override | |
+ public void setNextReader(IndexReader reader) throws IOException { | |
+ function.setNextReader(reader); | |
+ docSet = DocSets.convert(reader, filter.getDocIdSet(reader)); | |
+ } | |
+ | |
+ @Override | |
+ public float factor(int docId) { | |
+ if (!docSet.get(docId)) { | |
+ return Float.NaN; | |
+ } | |
+ return function.factor(docId); | |
+ } | |
+ | |
+ @Override | |
public boolean equals(Object o) { | |
if (this == o) return true; | |
if (o == null || getClass() != o.getClass()) return false; | |
@@ -114,7 +203,6 @@ | |
public static enum ScoreMode {First, Avg, Max, Total, Min, Multiply} | |
Query subQuery; | |
- Map<FilterFunction, DocSet> docSets = new HashMap<FilterFunction, DocSet>(); // TODO: something better than a map? | |
FilterScoreGroup filters; | |
public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions) { | |
@@ -132,12 +220,6 @@ | |
return subQuery; | |
} | |
- // 26-Dec-2011 NOTE: breaking API change from getFilterFunctions() : FilterFunction -- which was unused method in | |
- // ES code base, but not necessarily unused for users of ES | |
- public FilterScoreGroup getFilterFunctions() { | |
- return filters; | |
- } | |
- | |
@Override | |
public Query rewrite(IndexReader reader) throws IOException { | |
Query newQ = subQuery.rewrite(reader); | |
@@ -195,24 +277,10 @@ | |
if (subQueryScorer == null) { | |
return null; | |
} | |
- initFilterFunctionDocSets(reader, filters); | |
- return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, filters, docSets); | |
+ filters.setNextReader(reader); | |
+ return new CustomBoostFactorScorer(getSimilarity(searcher), this, subQueryScorer, filters); | |
} | |
- private void initFilterFunctionDocSets(IndexReader reader, FilterScoreGroup group) throws IOException { | |
- for (FilterItem item : group.filterItems) { | |
- if (item instanceof FilterScoreGroup) { | |
- initFilterFunctionDocSets(reader, (FilterScoreGroup) item); | |
- } else if (item instanceof FilterFunction) { | |
- FilterFunction filterFunction = (FilterFunction) item; | |
- if (!docSets.containsKey(filterFunction)) { | |
- filterFunction.function.setNextReader(reader); | |
- docSets.put(filterFunction, DocSets.convert(reader, filterFunction.filter.getDocIdSet(reader))); | |
- } | |
- } | |
- } | |
- } | |
- | |
@Override | |
public Explanation explain(IndexReader reader, int doc) throws IOException { | |
Explanation subQueryExpl = subQueryWeight.explain(reader, doc); | |
@@ -344,16 +412,12 @@ | |
private final float subQueryWeight; | |
private final Scorer scorer; | |
private final FilterScoreGroup filters; | |
- private final Map<FilterFunction, DocSet> docSets; | |
- | |
- private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer, | |
- FilterScoreGroup filters, Map<FilterFunction, DocSet> docSets) throws IOException { | |
+ private CustomBoostFactorScorer(Similarity similarity, CustomBoostFactorWeight w, Scorer scorer, FilterScoreGroup filters) throws IOException { | |
super(similarity); | |
this.subQueryWeight = w.getValue(); | |
this.scorer = scorer; | |
this.filters = filters; | |
- this.docSets = docSets; | |
} | |
@Override | |
@@ -373,101 +437,12 @@ | |
@Override | |
public float score() throws IOException { | |
- int docId = scorer.docID(); | |
- float score = scorer.score(); | |
- score = score(filters, docId, score); | |
- return subQueryWeight * score; | |
+ float queryScore = scorer.score(); | |
+ float factor = filters.factor(scorer.docID()); | |
+ if (!Float.isNaN(factor)) { | |
+ queryScore *= factor; | |
- } | |
+ } | |
- | |
- private float score(FilterScoreGroup group, int docId, float score) { | |
- float newScore = scoreOnlyMatchingFilters(group, docId, score); | |
- return newScore != Float.NEGATIVE_INFINITY ? newScore : score; | |
- } | |
- | |
- private float scoreOnlyMatchingFilters(FilterScoreGroup group, int docId, float subqueryScore) { | |
- if (group.scoreMode == ScoreMode.First) { | |
- for (FilterItem item : group.filterItems) { | |
- if (item instanceof FilterScoreGroup) { | |
- float tempScore = scoreOnlyMatchingFilters((FilterScoreGroup) item, docId, subqueryScore); | |
- if (tempScore != Float.NEGATIVE_INFINITY) { | |
- return tempScore; | |
- } | |
- } else { | |
- FilterFunction filterFunction = (FilterFunction) item; | |
- DocSet docSet = docSets.get(filterFunction); | |
- if (docSet != null && docSet.get(docId)) { | |
- return filterFunction.function.score(docId, subqueryScore); | |
- } | |
- } | |
- } | |
- } else if (group.scoreMode == ScoreMode.Max) { | |
- float maxScore = Float.NEGATIVE_INFINITY; | |
- for (FilterItem item : group.filterItems) { | |
- if (item instanceof FilterScoreGroup) { | |
- maxScore = Math.max(scoreOnlyMatchingFilters((FilterScoreGroup) item, docId, subqueryScore), maxScore); | |
- } else { | |
- FilterFunction filterFunction = (FilterFunction) item; | |
- DocSet docSet = docSets.get(filterFunction); | |
- if (docSet != null && docSet.get(docId)) { | |
- maxScore = Math.max(filterFunction.function.score(docId, subqueryScore), maxScore); | |
- } | |
- } | |
- } | |
- if (maxScore != Float.NEGATIVE_INFINITY) { | |
- return maxScore; | |
- } | |
- } else if (group.scoreMode == ScoreMode.Min) { | |
- float minScore = Float.POSITIVE_INFINITY; | |
- for (FilterItem item : group.filterItems) { | |
- if (item instanceof FilterScoreGroup) { | |
- float tempScore = scoreOnlyMatchingFilters((FilterScoreGroup) item, docId, subqueryScore); | |
- if (tempScore != Float.NEGATIVE_INFINITY) { | |
- minScore = Math.min(tempScore, minScore); | |
- } | |
- } else { | |
- FilterFunction filterFunction = (FilterFunction) item; | |
- DocSet docSet = docSets.get(filterFunction); | |
- if (docSet != null && docSet.get(docId)) { | |
- minScore = Math.min(filterFunction.function.score(docId, subqueryScore), minScore); | |
- } | |
- } | |
- } | |
- if (minScore != Float.POSITIVE_INFINITY) { | |
- return minScore; | |
- } | |
- } else { // Avg / Total / multiply | |
- float totalScore = 0.0f; | |
- float multiplicativeScore = 1.0f; | |
- int count = 0; | |
- for (FilterItem item : group.filterItems) { | |
- if (item instanceof FilterScoreGroup) { | |
- float tempScore = scoreOnlyMatchingFilters((FilterScoreGroup) item, docId, subqueryScore); | |
- if (tempScore != Float.NEGATIVE_INFINITY) { | |
- totalScore += tempScore; | |
- multiplicativeScore *= tempScore; | |
- count++; | |
- } | |
- } else { | |
- FilterFunction filterFunction = (FilterFunction) item; | |
- DocSet docSet = docSets.get(filterFunction); | |
- if (docSet != null && docSet.get(docId)) { | |
- float tempScore = filterFunction.function.score(docId, subqueryScore); | |
- totalScore += tempScore; | |
- multiplicativeScore *= tempScore; | |
- count++; | |
- } | |
- } | |
- } | |
- if (count != 0) { | |
- if (group.scoreMode == ScoreMode.Avg) { | |
- return totalScore / count; | |
- } else if (group.scoreMode == ScoreMode.Multiply) { | |
- return multiplicativeScore; | |
- } | |
- return totalScore; | |
- } | |
- } | |
- return Float.NEGATIVE_INFINITY; | |
+ return subQueryWeight * queryScore; | |
} | |
} | |
Index: src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java | |
IDEA additional info: | |
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | |
<+>UTF-8 | |
=================================================================== | |
--- src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java (date 1324946452000) | |
+++ src/main/java/org/elasticsearch/common/lucene/search/function/BoostScoreFunction.java (revision ) | |
@@ -49,6 +49,11 @@ | |
} | |
@Override | |
+ public float factor(int docId) { | |
+ return boost; | |
+ } | |
+ | |
+ @Override | |
public Explanation explain(int docId, Explanation subQueryExpl) { | |
Explanation exp = new Explanation(boost * subQueryExpl.getValue(), "static boost function: product of:"); | |
exp.addDetail(subQueryExpl); | |
@@ -71,5 +76,10 @@ | |
@Override | |
public int hashCode() { | |
return (boost != +0.0f ? Float.floatToIntBits(boost) : 0); | |
+ } | |
+ | |
+ @Override | |
+ public String toString() { | |
+ return "boost(" + boost + ")"; | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.