Skip to content

Instantly share code, notes, and snippets.

@mocobeta
Last active December 17, 2015 22:49
Show Gist options
  • Save mocobeta/5685087 to your computer and use it in GitHub Desktop.
Save mocobeta/5685087 to your computer and use it in GitHub Desktop.
Lucene でカスタムソートを実装するサンプル
/**
* 以下は、Apache Softoware Licence v2.0 の元に頒布されているコードに一部改変を加えたものです。
* http://www.apache.org/licenses/LICENSE-2.0.txt
*/
import java.io.IOException;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.search.FieldCache;
import org.apache.lucene.search.FieldCache.Ints;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldComparatorSource;
public class DistanceComparatorSource extends FieldComparatorSource {
private int x; // 基準点のX座標
private int y; // 基準点のY座標
public DistanceComparatorSource(int x, int y) {
// 距離計算の基準となる点の座標値
this.x = x;
this.y = y;
}
@Override
/** 新規Comparatorを返す */
public FieldComparator<Float> newComparator(String fieldName, int numHits, int sortPos,
boolean reversed) throws IOException {
// DistanceScoreLookupComparatorを生成
return new DistanceScoreDocLookupComparator(fieldName, numHits);
}
/** 距離計算を行うためのComparator */
private class DistanceScoreDocLookupComparator extends FieldComparator<Float> {
private Ints xDoc; // 登録されているドキュメントのX座標フィールド値
private Ints yDoc; // 登録されているドキュメントのY座標フィールド値
private float[] values; // ソートに使用する値(ここでは基準点(x,y)からの距離)を保持しておく配列
private float bottom; // 最後尾の値を保持しておく変数
String fieldName;
public DistanceScoreDocLookupComparator(String fieldName, int numHits) {
values = new float[numHits];
this.fieldName = fieldName;
}
@Override
public FieldComparator<Float> setNextReader(AtomicReaderContext cont)
throws IOException {
// フィールドキャッシュからX座標、Y座標の値をそれぞれ読み出す
xDoc = FieldCache.DEFAULT.getInts(cont.reader(), "x", false);
yDoc = FieldCache.DEFAULT.getInts(cont.reader(), "y", false);
return this;
}
@Override
public int compare(int slot1, int slot2) {
// slot1 番目と slot2番目の値(距離)を比較する
if (values[slot1] < values[slot2]) return -1;
if (values[slot1] > values[slot2]) return 1;
return 0;
}
@Override
public void setBottom(int slot) {
bottom = values[slot];
}
@Override
public int compareBottom(int doc) throws IOException {
// 基準点(x,y)と指定されたドキュメントとの距離を求め、現在の最後尾の値と比較
float docDistance = getDistance(doc);
if (bottom < docDistance) return -1;
if (bottom > docDistance) return 1;
return 0;
}
private float getDistance(int doc) {
// 基準点(x,y)と指定されたドキュメント(の表す点)との距離を計算する
int deltax = xDoc.get(doc) - x;
int deltay = yDoc.get(doc) - y;
return (float) Math.sqrt(deltax * deltax + deltay * deltay);
}
@Override
public void copy(int slot, int doc) throws IOException {
// values配列に、getDistance()メソッドで計算される値を設定する
values[slot] = getDistance(doc);
}
@Override
public Float value(int slot) {
// slot で指定される値を返却する
return values[slot];
}
@Override
public int compareDocToValue(int doc, Float val) throws IOException {
float docDistance = getDistance(doc);
if (docDistance < val) return -1;
if (docDistance > val) return 1;
return 0;
}
}
}
/**
* 以下は、Apache Softoware Licence v2.0 の元に頒布されているコードに一部改変を加えたものです。
* http://www.apache.org/licenses/LICENSE-2.0.txt
*/
import static org.junit.Assert.*;
import java.io.IOException;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.IntField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.Version;
import org.junit.Before;
import org.junit.Test;
public class DistanceSortingTest {
private RAMDirectory directory;
@Before
public void setUp() throws IOException {
directory = new RAMDirectory();
IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_43, new WhitespaceAnalyzer(Version.LUCENE_43));
IndexWriter writer = new IndexWriter(directory, conf);
// 店名、種別、座標の情報をドキュメントとしてインデキシング
addPoint(writer, "ここす", "restaurant", 5, 9);
addPoint(writer, "でにーず", "restaurant", 1, 2);
addPoint(writer, "じょなさん", "restaurant", 9, 6);
addPoint(writer, "がすと", "restaurant", 3, 8);
addPoint(writer, "ばーみやん", "restaurant", 4, 3);
addPoint(writer, "すたば", "cafe", 2, 1);
addPoint(writer, "どとーる", "cafe", 4, 4);
addPoint(writer, "たりーず", "cafe", 8, 3);
writer.close();
}
private void addPoint(IndexWriter writer, String name, String type, int x, int y)
throws IOException {
// 店名、種別、X座標、Y座標をドキュメントとしてインデキシング
Document doc = new Document();
doc.add(new StringField("name", name, Store.YES));
doc.add(new StringField("type", type, Store.YES));
doc.add(new IntField("x", x, Store.YES));
doc.add(new IntField("y", y, Store.YES));
writer.addDocument(doc);
}
@Test
public void testNearestRestaurantToWork() throws IOException {
// レストランを検索し、職場(0,0)から近い順にソートする
IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(directory));
Query query = new TermQuery(new Term("type", "restaurant"));
Sort sort = new Sort(new SortField("location", new DistanceComparatorSource(0, 0)));
TopDocs hits = searcher.search(query, null, 10, sort);
// (0,0)から直線距離が近い順に並んでいること、また正しく距離が計算されていることを確認
FieldDoc fieldDoc = (FieldDoc) hits.scoreDocs[0];
assertEquals("でにーず", searcher.doc(fieldDoc.doc).get("name"));
assertEquals((float) Math.sqrt(1*1 + 2*2), fieldDoc.fields[0]);
fieldDoc = (FieldDoc) hits.scoreDocs[1];
assertEquals("ばーみやん", searcher.doc(fieldDoc.doc).get("name"));
assertEquals((float) Math.sqrt(4*4 + 3*3), fieldDoc.fields[0]);
fieldDoc = (FieldDoc) hits.scoreDocs[2];
assertEquals("がすと", searcher.doc(fieldDoc.doc).get("name"));
assertEquals((float) Math.sqrt(3*3 + 8*8), fieldDoc.fields[0]);
fieldDoc = (FieldDoc) hits.scoreDocs[3];
assertEquals("ここす", searcher.doc(fieldDoc.doc).get("name"));
assertEquals((float) Math.sqrt(5*5 + 9*9), fieldDoc.fields[0]);
fieldDoc = (FieldDoc) hits.scoreDocs[4];
assertEquals("じょなさん", searcher.doc(fieldDoc.doc).get("name"));
assertEquals((float) Math.sqrt(9*9 + 6*6), fieldDoc.fields[0]);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment