Skip to content

Instantly share code, notes, and snippets.

@ppat
Created October 24, 2013 14:49
Show Gist options
  • Save ppat/7138638 to your computer and use it in GitHub Desktop.
Save ppat/7138638 to your computer and use it in GitHub Desktop.
/**
 * Native score script that compares a weighted item list supplied with the query against a
 * weighted item list stored on each document.
 *
 * <p>The score is {@code (intersectionScore / unionScore) * coefficient}, where
 * {@code intersectionScore} sums the smaller of the two weights for every item present on both
 * sides, and {@code unionScore} sums the larger of the two weights for every item present on
 * either side (a missing item counts as weight {@code 0.0}).
 *
 * <p>Expected script parameters (all required, all read via {@code toString()}):
 * <ul>
 *   <li>{@code coefficient} - multiplier applied to the ratio</li>
 *   <li>{@code docItemIdField} - document field holding comma-delimited item ids</li>
 *   <li>{@code docItemScoreField} - document field holding comma-delimited item scores</li>
 *   <li>{@code queryItemIds} - comma-delimited item ids of the query</li>
 *   <li>{@code queryItemScores} - comma-delimited item scores of the query</li>
 * </ul>
 *
 * <p>If the parameters are missing, blank or not parseable, the script logs a warning once at
 * construction time and scores every document {@code 0.0}.
 */
public class JaccardScoreScript extends AbstractDoubleSearchScript
{
    private static final ESLogger logger = Loggers.getLogger(JaccardScoreScript.class);

    private final double coefficient;
    private final String docItemField;
    private final String docItemScoreField;
    private final List<Long> queryItems;
    private final List<Double> queryItemScores;
    /** When {@code true} the parameters were unusable and every document scores {@code 0.0}. */
    private final boolean zero;

    /**
     * Parses the script parameters.
     *
     * @param params script parameters from the query; may be {@code null}, in which case the
     *               script is put in the "always score zero" state
     */
    public JaccardScoreScript(@Nullable Map<String, Object> params)
    {
        boolean usable = false;
        double parsedCoefficient = 0.0;
        String idField = null;
        String scoreField = null;
        List<Long> ids = null;
        List<Double> scores = null;

        if (isParamsValid(params)) {
            try {
                parsedCoefficient = Double.parseDouble(params.get("coefficient").toString());
                idField = params.get("docItemIdField").toString();
                scoreField = params.get("docItemScoreField").toString();
                ids = asLongs(splitAsList(params.get("queryItemIds").toString()));
                scores = asDoubles(splitAsList(params.get("queryItemScores").toString()));
                usable = true;
            } catch (NumberFormatException e) {
                // A malformed number is just another kind of invalid parameter: degrade to a
                // zero score instead of failing script construction.
                logger.warn(String.format("invalid params in query: %s", params), e);
            }
        } else {
            // String.format renders a null argument as "null", so this is safe for null params.
            logger.warn(String.format("invalid params in query: %s", params));
        }

        // Assign the final fields exactly once; partially parsed values are discarded.
        zero = !usable;
        coefficient = usable ? parsedCoefficient : 0.0;
        docItemField = usable ? idField : null;
        docItemScoreField = usable ? scoreField : null;
        queryItems = usable ? ids : null;
        queryItemScores = usable ? scores : null;
    }

    /**
     * Checks that every required parameter is present and has non-blank text.
     *
     * @param params script parameters, possibly {@code null}
     * @return {@code true} only if all five required keys map to non-null, non-blank values
     */
    private boolean isParamsValid(Map<String, Object> params)
    {
        if (params == null) {
            return false;
        }
        return hasTextParam(params, "docItemIdField")
                && hasTextParam(params, "docItemScoreField")
                && hasTextParam(params, "queryItemIds")
                && hasTextParam(params, "queryItemScores")
                && hasTextParam(params, "coefficient");
    }

    /** Returns {@code true} if {@code key} maps to a non-null value whose text is not blank. */
    private static boolean hasTextParam(Map<String, Object> params, String key)
    {
        Object value = params.get(key);
        return value != null && Strings.hasText(value.toString());
    }

    /**
     * Scores the current document.
     *
     * @return {@code (intersectionScore / unionScore) * coefficient}, or {@code 0.0} when the
     *         parameters were invalid, the query or document item lists are empty, ids and
     *         scores differ in length, the document values cannot be parsed, or the union
     *         score is zero
     */
    @Override
    public double runAsDouble()
    {
        if (zero) {
            return 0.0;
        }
        if (queryItems.isEmpty() || queryItems.size() != queryItemScores.size()) {
            return 0.0;
        }

        List<Long> docItems;
        List<Double> docItemScores;
        try {
            docItems = asLongs(splitAsList(docFieldStrings(docItemField).getValue()));
            docItemScores = asDoubles(splitAsList(docFieldStrings(docItemScoreField).getValue()));
        } catch (NumberFormatException e) {
            // Treat unparseable stored values like any other bad document data: skip the doc.
            logger.warn(String.format("doc has unparseable values in %s or %s: %s",
                                      docItemField,
                                      docItemScoreField,
                                      e.getMessage()));
            return 0.0;
        }

        if (docItems.isEmpty()) {
            if (logger.isDebugEnabled()) {
                logger.debug(String.format("doc does not have values for %s, skipping doc", docItemField));
            }
            return 0.0;
        }
        if (docItems.size() != docItemScores.size()) {
            logger.warn(String.format("doc does not have corresponding scores (%s) for corresponding items (%s)",
                                      docItemScoreField,
                                      docItemField));
            return 0.0;
        }

        double intersectionScore = intersectionScore(docItems, docItemScores, queryItems, queryItemScores);
        double unionScore = unionScore(docItems, docItemScores, queryItems, queryItemScores);
        if (unionScore == 0.0) {
            // All weights are zero: the ratio would be 0/0 = NaN, which is not a usable score.
            return 0.0;
        }
        double score = (intersectionScore / unionScore) * coefficient;

        if (logger.isDebugEnabled()) {
            logger.debug(String.format(
                    "\r\n\tcoefficient: %s, queryItems: %s, queryItemScores: %s, \r\n\tdocItems: %s, docItemScores: %s, \r\n\tintersectionScore: %s, unionScore: %s, score: %s",
                    Double.toString(coefficient),
                    queryItems.toString(),
                    queryItemScores.toString(),
                    docItems.toString(),
                    docItemScores.toString(),
                    Double.toString(intersectionScore),
                    Double.toString(unionScore),
                    Double.toString(score)));
        }
        return score;
    }

    /** Returns the distinct elements that occur in both {@code a} and {@code b}. */
    private <T> Collection<T> intersection(List<T> a,
                                           List<T> b)
    {
        Set<T> set_a = asSet(a);
        Set<T> set_b = asSet(b);
        Set<T> intersection = new HashSet<T>();
        intersection.addAll(set_a);
        intersection.retainAll(set_b);
        return intersection;
    }

    /** Returns the distinct elements that occur in {@code a}, in {@code b}, or in both. */
    private <T> Collection<T> union(List<T> a,
                                    List<T> b)
    {
        Set<T> set_a = asSet(a);
        Set<T> set_b = asSet(b);
        Set<T> union = new HashSet<T>();
        union.addAll(set_a);
        union.addAll(set_b);
        return union;
    }

    /** Copies {@code values} into a new {@link HashSet}, dropping duplicates. */
    private <T> Set<T> asSet(List<T> values)
    {
        return new HashSet<T>(values);
    }

    /**
     * Splits a comma-delimited string.
     *
     * @param s the delimited text; may be {@code null} or blank
     * @return the raw (untrimmed) tokens, or an empty list when {@code s} has no text
     */
    private List<String> splitAsList(String s)
    {
        if (Strings.hasText(s)) {
            return Arrays.asList(s.split(","));
        }
        return Collections.emptyList();
    }

    /**
     * Parses each token as a {@code long} after trimming whitespace.
     *
     * @throws NumberFormatException if any token is not a valid long
     */
    private List<Long> asLongs(List<String> values)
    {
        List<Long> list = new ArrayList<Long>(values.size());
        for (String s : values) {
            list.add(Long.valueOf(Strings.trimWhitespace(s)));
        }
        return list;
    }

    /**
     * Parses each token as a {@code double} after trimming whitespace.
     *
     * @throws NumberFormatException if any token is not a valid double
     */
    private List<Double> asDoubles(List<String> values)
    {
        List<Double> list = new ArrayList<Double>(values.size());
        for (String s : values) {
            list.add(Double.valueOf(Strings.trimWhitespace(s)));
        }
        return list;
    }

    /**
     * Looks up the score paired with {@code item}.
     *
     * @param items  item ids; the first occurrence of {@code item} determines the position
     * @param scores scores positionally aligned with {@code items}
     * @param item   the item id to look up
     * @return the score at the item's position, or {@code 0.0} if the item is not in the list
     */
    private static double scoreOf(List<Long> items, List<Double> scores, Long item)
    {
        int index = items.indexOf(item);
        return (index >= 0) ? scores.get(index) : 0.0;
    }

    /**
     * Sums, over every item present in both lists, the smaller of its document and query scores.
     * Protected for testing.
     *
     * @param docItems        item ids of the document
     * @param docItemScores   scores positionally aligned with {@code docItems}
     * @param queryItems      item ids of the query
     * @param queryItemScores scores positionally aligned with {@code queryItems}
     * @return the summed minimum scores; {@code 0.0} when the lists share no item
     */
    protected double intersectionScore(List<Long> docItems,
                                       List<Double> docItemScores,
                                       List<Long> queryItems,
                                       List<Double> queryItemScores)
    {
        double intersectionScore = 0.0;
        for (Long i : intersection(docItems, queryItems)) {
            double docItemScore = scoreOf(docItems, docItemScores, i);
            double queryItemScore = scoreOf(queryItems, queryItemScores, i);
            intersectionScore += Math.min(docItemScore, queryItemScore);
        }
        return intersectionScore;
    }

    /**
     * Sums, over every item present in either list, the larger of its document and query scores;
     * an item missing from one side contributes {@code 0.0} for that side.
     * Protected for testing.
     *
     * @param docItems        item ids of the document
     * @param docItemScores   scores positionally aligned with {@code docItems}
     * @param queryItems      item ids of the query
     * @param queryItemScores scores positionally aligned with {@code queryItems}
     * @return the summed maximum scores
     */
    protected double unionScore(List<Long> docItems,
                                List<Double> docItemScores,
                                List<Long> queryItems,
                                List<Double> queryItemScores)
    {
        double unionScore = 0.0;
        for (Long u : union(docItems, queryItems)) {
            double docItemScore = scoreOf(docItems, docItemScores, u);
            double queryItemScore = scoreOf(queryItems, queryItemScores, u);
            unionScore += Math.max(docItemScore, queryItemScore);
        }
        return unionScore;
    }
}
{
"mappings": {
"thing": {
"properties": {
"x": {
"type": "multi_field",
"fields": {
"x": {
"type": "integer",
"store": "no",
"index": "analyzed",
"include_in_all": false
},
"delimitedIds": {
"type": "string",
"store": "yes",
"index": "no",
"include_in_all": false
},
"delimitedScores": {
"type": "string",
"store": "yes",
"index": "no",
"include_in_all": false
}
}
},
"y": {
"type": "multi_field",
"fields": {
"y": {
"type": "integer",
"store": "no",
"index": "analyzed",
"include_in_all": false
},
"delimitedIds": {
"type": "string",
"store": "yes",
"index": "no",
"include_in_all": false
},
"delimitedScores": {
"type": "string",
"store": "yes",
"index": "no",
"include_in_all": false
}
}
},
"z": {
"type": "multi_field",
"fields": {
"z": {
"type": "integer",
"store": "no",
"index": "analyzed",
"include_in_all": false
},
"delimitedIds": {
"type": "string",
"store": "yes",
"index": "no",
"include_in_all": false
},
"delimitedScores": {
"type": "string",
"store": "yes",
"index": "no",
"include_in_all": false
}
}
}
}
}
}
}
{
"function_score" : {
"query" : {
"match_all" : { }
},
"functions" : [ {
"filter" : {
"terms" : {
"x.x" : [ "103", "104", "134", "180" ],
"_cache" : true
}
},
"script_score" : {
"script" : "jaccardscorescript",
"lang" : "native",
"params" : {
"coefficient" : "7.6642",
"queryItemScores" : "1.0,1.0,0.4,0.24",
"queryItemIds" : "103,104,134,180",
"docItemScoreField" : "x.delimitedScores",
"docItemIdField" : "x.delimitedIds"
}
}
}, {
"filter" : {
"terms" : {
"y.y" : [ "100", "97", "94" ],
"_cache" : true
}
},
"script_score" : {
"script" : "jaccardscorescript",
"lang" : "native",
"params" : {
"coefficient" : "0.79509",
"queryItemScores" : "1.0,0.5,0.5",
"queryItemIds" : "100,97,94",
"docItemScoreField" : "y.delimitedScores",
"docItemIdField" : "y.delimitedIds"
}
}
}, {
"filter" : {
"terms" : {
"z.z" : [ "5", "10", "25", "1", "6", "21", "9", "2", "22", "7", "3", "23", "8", "4" ],
"_cache" : true
}
},
"script_score" : {
"script" : "jaccardscorescript",
"lang" : "native",
"params" : {
"coefficient" : "2.82619",
"queryItemScores" : "0.75,0.25,0.25,0.75,0.75,0.25,0.75,0.75,0.75,0.75,0.75,0.75,1.0,0.75",
"queryItemIds" : "5,10,25,1,6,21,9,2,22,7,3,23,8,4",
"docItemScoreField" : "z.delimitedScores",
"docItemIdField" : "z.delimitedIds"
}
}
}, {
"filter" : {
"match_all" : { }
},
"script_score" : {
"script" : "-1.3373"
}
} ],
"score_mode" : "sum"
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment