Last active
December 17, 2015 22:49
-
-
Save mocobeta/5685087 to your computer and use it in GitHub Desktop.
Lucene でカスタムソートを実装するサンプル
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* 以下は、Apache License, Version 2.0 のもとで頒布されているコードに一部改変を加えたものです。
* http://www.apache.org/licenses/LICENSE-2.0.txt | |
*/ | |
import java.io.IOException; | |
import org.apache.lucene.index.AtomicReaderContext; | |
import org.apache.lucene.search.FieldCache; | |
import org.apache.lucene.search.FieldCache.Ints; | |
import org.apache.lucene.search.FieldComparator; | |
import org.apache.lucene.search.FieldComparatorSource; | |
public class DistanceComparatorSource extends FieldComparatorSource { | |
private int x; // 基準点のX座標 | |
private int y; // 基準点のY座標 | |
public DistanceComparatorSource(int x, int y) { | |
// 距離計算の基準となる点の座標値 | |
this.x = x; | |
this.y = y; | |
} | |
@Override | |
/** 新規Comparatorを返す */ | |
public FieldComparator<Float> newComparator(String fieldName, int numHits, int sortPos, | |
boolean reversed) throws IOException { | |
// DistanceScoreLookupComparatorを生成 | |
return new DistanceScoreDocLookupComparator(fieldName, numHits); | |
} | |
/** 距離計算を行うためのComparator */ | |
private class DistanceScoreDocLookupComparator extends FieldComparator<Float> { | |
private Ints xDoc; // 登録されているドキュメントのX座標フィールド値 | |
private Ints yDoc; // 登録されているドキュメントのY座標フィールド値 | |
private float[] values; // ソートに使用する値(ここでは基準点(x,y)からの距離)を保持しておく配列 | |
private float bottom; // 最後尾の値を保持しておく変数 | |
String fieldName; | |
public DistanceScoreDocLookupComparator(String fieldName, int numHits) { | |
values = new float[numHits]; | |
this.fieldName = fieldName; | |
} | |
@Override | |
public FieldComparator<Float> setNextReader(AtomicReaderContext cont) | |
throws IOException { | |
// フィールドキャッシュからX座標、Y座標の値をそれぞれ読み出す | |
xDoc = FieldCache.DEFAULT.getInts(cont.reader(), "x", false); | |
yDoc = FieldCache.DEFAULT.getInts(cont.reader(), "y", false); | |
return this; | |
} | |
@Override | |
public int compare(int slot1, int slot2) { | |
// slot1 番目と slot2番目の値(距離)を比較する | |
if (values[slot1] < values[slot2]) return -1; | |
if (values[slot1] > values[slot2]) return 1; | |
return 0; | |
} | |
@Override | |
public void setBottom(int slot) { | |
bottom = values[slot]; | |
} | |
@Override | |
public int compareBottom(int doc) throws IOException { | |
// 基準点(x,y)と指定されたドキュメントとの距離を求め、現在の最後尾の値と比較 | |
float docDistance = getDistance(doc); | |
if (bottom < docDistance) return -1; | |
if (bottom > docDistance) return 1; | |
return 0; | |
} | |
private float getDistance(int doc) { | |
// 基準点(x,y)と指定されたドキュメント(の表す点)との距離を計算する | |
int deltax = xDoc.get(doc) - x; | |
int deltay = yDoc.get(doc) - y; | |
return (float) Math.sqrt(deltax * deltax + deltay * deltay); | |
} | |
@Override | |
public void copy(int slot, int doc) throws IOException { | |
// values配列に、getDistance()メソッドで計算される値を設定する | |
values[slot] = getDistance(doc); | |
} | |
@Override | |
public Float value(int slot) { | |
// slot で指定される値を返却する | |
return values[slot]; | |
} | |
@Override | |
public int compareDocToValue(int doc, Float val) throws IOException { | |
float docDistance = getDistance(doc); | |
if (docDistance < val) return -1; | |
if (docDistance > val) return 1; | |
return 0; | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* 以下は、Apache License, Version 2.0 のもとで頒布されているコードに一部改変を加えたものです。
* http://www.apache.org/licenses/LICENSE-2.0.txt | |
*/ | |
import static org.junit.Assert.*;
import java.io.IOException;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.IntField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.Version;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
public class DistanceSortingTest { | |
private RAMDirectory directory; | |
@Before | |
public void setUp() throws IOException { | |
directory = new RAMDirectory(); | |
IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_43, new WhitespaceAnalyzer(Version.LUCENE_43)); | |
IndexWriter writer = new IndexWriter(directory, conf); | |
// 店名、種別、座標の情報をドキュメントとしてインデキシング | |
addPoint(writer, "ここす", "restaurant", 5, 9); | |
addPoint(writer, "でにーず", "restaurant", 1, 2); | |
addPoint(writer, "じょなさん", "restaurant", 9, 6); | |
addPoint(writer, "がすと", "restaurant", 3, 8); | |
addPoint(writer, "ばーみやん", "restaurant", 4, 3); | |
addPoint(writer, "すたば", "cafe", 2, 1); | |
addPoint(writer, "どとーる", "cafe", 4, 4); | |
addPoint(writer, "たりーず", "cafe", 8, 3); | |
writer.close(); | |
} | |
private void addPoint(IndexWriter writer, String name, String type, int x, int y) | |
throws IOException { | |
// 店名、種別、X座標、Y座標をドキュメントとしてインデキシング | |
Document doc = new Document(); | |
doc.add(new StringField("name", name, Store.YES)); | |
doc.add(new StringField("type", type, Store.YES)); | |
doc.add(new IntField("x", x, Store.YES)); | |
doc.add(new IntField("y", y, Store.YES)); | |
writer.addDocument(doc); | |
} | |
@Test | |
public void testNearestRestaurantToWork() throws IOException { | |
// レストランを検索し、職場(0,0)から近い順にソートする | |
IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(directory)); | |
Query query = new TermQuery(new Term("type", "restaurant")); | |
Sort sort = new Sort(new SortField("location", new DistanceComparatorSource(0, 0))); | |
TopDocs hits = searcher.search(query, null, 10, sort); | |
// (0,0)から直線距離が近い順に並んでいること、また正しく距離が計算されていることを確認 | |
FieldDoc fieldDoc = (FieldDoc) hits.scoreDocs[0]; | |
assertEquals("でにーず", searcher.doc(fieldDoc.doc).get("name")); | |
assertEquals((float) Math.sqrt(1*1 + 2*2), fieldDoc.fields[0]); | |
fieldDoc = (FieldDoc) hits.scoreDocs[1]; | |
assertEquals("ばーみやん", searcher.doc(fieldDoc.doc).get("name")); | |
assertEquals((float) Math.sqrt(4*4 + 3*3), fieldDoc.fields[0]); | |
fieldDoc = (FieldDoc) hits.scoreDocs[2]; | |
assertEquals("がすと", searcher.doc(fieldDoc.doc).get("name")); | |
assertEquals((float) Math.sqrt(3*3 + 8*8), fieldDoc.fields[0]); | |
fieldDoc = (FieldDoc) hits.scoreDocs[3]; | |
assertEquals("ここす", searcher.doc(fieldDoc.doc).get("name")); | |
assertEquals((float) Math.sqrt(5*5 + 9*9), fieldDoc.fields[0]); | |
fieldDoc = (FieldDoc) hits.scoreDocs[4]; | |
assertEquals("じょなさん", searcher.doc(fieldDoc.doc).get("name")); | |
assertEquals((float) Math.sqrt(9*9 + 6*6), fieldDoc.fields[0]); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment