Created
September 22, 2012 09:43
-
-
Save komiya-atsushi/3765693 to your computer and use it in GitHub Desktop.
Double Array Trie の Java 実装を Map インタフェースでラップしたもの。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.BufferedInputStream; | |
import java.io.BufferedOutputStream; | |
import java.io.DataInputStream; | |
import java.io.DataOutputStream; | |
import java.io.File; | |
import java.io.FileInputStream; | |
import java.io.FileOutputStream; | |
import java.io.IOException; | |
import java.util.ArrayList; | |
import java.util.Collection; | |
import java.util.Collections; | |
import java.util.List; | |
import java.util.Map; | |
import java.util.Set; | |
import java.util.Vector; | |
/** | |
* Double Array Trie の Java 実装を Map インタフェースでラップしたクラスです。 | |
* <p> | |
* Java 実装は以下を利用しています。<br /> | |
* http://nlp.ist.i.kyoto-u.ac.jp/member/murawaki/misc/index.html | |
* </p> | |
* | |
* <p> | |
* Map インタフェースのうち、以下のメソッドが実装されています。 | |
* <ul> | |
* <li><code>{@link Map#containsKey(Object)}</code></li> | |
* <li><code>{@link Map#get(Object)}</code></li> | |
* <li><code>{@link Map#isEmpty()}</code></li> | |
* <li><code>{@link Map#size()}</code></li> | |
* </ul> | |
* </p> | |
* | |
* <p> | |
* また、Trie 特有の操作として、以下のメソッドが実装されています。 | |
* <ul> | |
* <li><code>{@link #commonPrefixSearch(String)}</code></li> | |
* </ul> | |
* </p> | |
* | |
* @author KOMIYA Atsushi | |
*/ | |
public class DartsMap<V> implements Map<String, V> { | |
/** Trie の各種操作を提供するオブジェクト */ | |
private DoubleArrayTrie trie; | |
/** Trie に格納されている各キーに対応する値。キーとは要素番号で紐付けされています。 */ | |
private List<V> values; | |
/** | |
* DartsMap オブジェクトを生成して返却します。 | |
* | |
* @param keys | |
* キーとなる文字列を格納したリスト。このリストの内容をもとに Trie が構築されます。 | |
* @param values | |
* キーに対応する値。keys の順序に対応しています。keys と同じ長さである必要があります。 | |
* @return 生成された DartsMap オブジェクト | |
*/ | |
public static <V> DartsMap<V> create(List<String> keys, List<V> values) { | |
if (keys == null) { | |
throw new IllegalArgumentException("keys に null を指定することはできません。"); | |
} | |
if (values == null) { | |
throw new IllegalArgumentException("values に null を指定することはできません。"); | |
} | |
if (keys.size() != values.size()) { | |
throw new IllegalArgumentException( | |
"keys と values の長さは同じである必要があります。"); | |
} | |
DoubleArrayTrie trie = new DoubleArrayTrie(); | |
char[][] charsArray = new char[keys.size()][]; | |
int[] keyLengthes = new int[keys.size()]; | |
for (int i = 0; i < keys.size(); i++) { | |
String key = keys.get(i); | |
charsArray[i] = key.toCharArray(); | |
keyLengthes[i] = key.length(); | |
} | |
int[] indexes = new int[values.size()]; | |
for (int i = 0; i < indexes.length; i++) { | |
indexes[i] = i; | |
} | |
trie.build(charsArray, keyLengthes, indexes); | |
return new DartsMap<V>(trie, values); | |
} | |
private DartsMap(DoubleArrayTrie trie, List<V> values) { | |
this.trie = trie; | |
this.values = values; | |
} | |
/** | |
* 指定されたキーで共通接頭辞探索を行い、得られた値をリストに格納して返却します。 | |
* | |
* @param key | |
* @return | |
*/ | |
public List<V> commonPrefixSearch(String key) { | |
int[] indexes = new int[values.size()]; | |
int indexCount = trie.commonPrefixSearch(key.toCharArray(), indexes, | |
indexes.length, 0); | |
if (indexCount == 0) { | |
return Collections.emptyList(); | |
} | |
List<V> result = new ArrayList<V>(indexCount); | |
for (int i = 0; i < indexCount; i++) { | |
result.add(values.get(indexes[i])); | |
} | |
return result; | |
} | |
@Override | |
public int size() { | |
return values.size(); | |
} | |
@Override | |
public boolean isEmpty() { | |
return size() == 0; | |
} | |
@Override | |
public boolean containsKey(Object key) { | |
char[] chars = ((String) key).toCharArray(); | |
int index = trie.exactMatchSearch(chars, 0); | |
return index >= 0; | |
} | |
@Override | |
public V get(Object key) { | |
char[] chars = ((String) key).toCharArray(); | |
int index = trie.exactMatchSearch(chars, 0); | |
if (index < 0) { | |
return null; | |
} | |
return values.get(index); | |
} | |
// 以下は提供されないメソッド | |
@Override | |
public V remove(Object key) { | |
throw createUnsupportedOperationException(); | |
} | |
@Override | |
public void clear() { | |
throw createUnsupportedOperationException(); | |
} | |
@Override | |
public Set<String> keySet() { | |
throw createUnsupportedOperationException(); | |
} | |
@Override | |
public Collection<V> values() { | |
throw createUnsupportedOperationException(); | |
} | |
@Override | |
public boolean containsValue(Object value) { | |
throw createUnsupportedOperationException(); | |
} | |
@Override | |
public Set<Entry<String, V>> entrySet() { | |
throw createUnsupportedOperationException(); | |
} | |
@Override | |
public V put(String key, V value) { | |
throw createUnsupportedOperationException(); | |
} | |
@Override | |
public void putAll(Map<? extends String, ? extends V> m) { | |
throw createUnsupportedOperationException(); | |
} | |
/** | |
* DartsMap として実装を提供しない Map インタフェースのメソッドが呼ばれた時に | |
* 投げる実行時例外オブジェクトを生成して返却します。 | |
* <p> | |
* 返却される例外オブジェクトには、呼ばれたメソッドの名前が含まれます。 | |
* </p> | |
* | |
* @return 例外オブジェクト | |
*/ | |
private RuntimeException createUnsupportedOperationException() { | |
StackTraceElement elem = new Exception().getStackTrace()[1]; | |
String message = String.format("DartsMap#%s() は実装されていません。", | |
elem.getMethodName()); | |
return new UnsupportedOperationException(message); | |
} | |
} | |
/** | |
* DoubleArrayTrie: Java implementation of Darts (Double-ARray Trie System) | |
* | |
* <p> | |
* Copyright(C) 2001-2007 Taku Kudo <[email protected]><br /> | |
* Copyright(C) 2009 MURAWAKI Yugo <[email protected]> | |
* </p> | |
* | |
* <p> | |
* The contents of this file may be used under the terms of either of | |
* the GNU Lesser General Public License Version 2.1 or later (the | |
* "LGPL"), or the BSD License (the "BSD"). | |
* </p> | |
*/ | |
class DoubleArrayTrie {
    /** Read buffer size, in bytes, used by {@link #open(String)}. */
    private final static int BUF_SIZE = 16384;
    /** Serialized size of one unit: an int base followed by an int check. */
    private final static int UNIT_SIZE = 8; // size of int + int
    /** Number of units allocated when {@link #build} starts. */
    private final static int INITIAL_ALLOC_SIZE = 8192;

    /**
     * A transition being laid out during build. Keys key[left .. right) all
     * continue from the parent node with the same code (char + 1, or 0 when
     * the key ends at the parent). depth is the parent's depth + 1.
     */
    private static class Node {
        int code;
        int depth;
        int left;
        int right;
    }

    /**
     * One double-array cell. A child reached from a node whose base is b via
     * code c lives at index b + c, and its check holds b. A negative base at an
     * end-of-key slot encodes the stored value as -(value + 1).
     */
    private static class Unit {
        int base;
        int check;
    }

    /** The double array; array.length is allocSize after build, size after open. */
    private Unit[] array;
    /** Build-time marks of base offsets already taken, so no two nodes share a base. */
    private int[] used;
    /** One past the highest occupied index; this many units are saved. */
    private int size;
    /** Current length of array/used during build. */
    private int allocSize;
    /** Build input: the keys (must be in ascending char order). */
    private char[][] key;
    /** Build input: number of keys to insert. */
    private int keySize;
    /** Build input: optional per-key lengths overriding key[i].length. */
    private int[] length;
    /** Build input: optional per-key values; when null the key index is stored. */
    private int[] value;
    /** Build-time count of keys whose leaf has been written. */
    private int progress;
    /** Build-time lower bound for the next free-slot scan. */
    private int nextCheckPos;
    // boolean no_delete_;
    /** 0 on success; -2 for a negative value, -3 for unsorted keys. */
    int error_;
    // int (*progressfunc_) (size_t, size_t);

    /**
     * Grows array and used to newSize units, preserving existing entries and
     * filling new slots with zeroed units.
     *
     * @return the new allocSize
     */
    private int resize(int newSize) {
        int kept = Math.min(allocSize, newSize);
        Unit[] grownArray = new Unit[newSize];
        int[] grownUsed = new int[newSize];
        // System.arraycopy throws on a null source even for length 0, so guard it.
        if (kept > 0) {
            System.arraycopy(array, 0, grownArray, 0, kept);
            System.arraycopy(used, 0, grownUsed, 0, kept);
        }
        for (int i = kept; i < newSize; i++) {
            grownArray[i] = new Unit();
        }
        array = grownArray;
        used = grownUsed;
        return allocSize = newSize;
    }

    /**
     * Collects the distinct child codes of parent into siblings.
     *
     * @return the number of children found; 0 for a leaf or when an error is
     *         set (error_ becomes -3 if the keys are not sorted)
     */
    private int fetch(Node parent, List<Node> siblings) {
        if (error_ < 0) {
            return 0;
        }
        int prev = 0;
        for (int i = parent.left; i < parent.right; i++) {
            int keyLength = (length != null) ? length[i] : key[i].length;
            if (keyLength < parent.depth) {
                continue;
            }
            // Code 0 marks "key ends here"; real characters are shifted by one.
            int cur = (keyLength != parent.depth) ? key[i][parent.depth] + 1 : 0;
            if (prev > cur) {
                error_ = -3;
                return 0;
            }
            if (cur != prev || siblings.isEmpty()) {
                Node node = new Node();
                node.depth = parent.depth + 1;
                node.code = cur;
                node.left = i;
                if (!siblings.isEmpty()) {
                    siblings.get(siblings.size() - 1).right = i;
                }
                siblings.add(node);
            }
            prev = cur;
        }
        if (!siblings.isEmpty()) {
            siblings.get(siblings.size() - 1).right = parent.right;
        }
        return siblings.size();
    }

    /**
     * Finds a base offset where all siblings fit into free slots, claims those
     * slots, and recursively inserts each sibling's subtree.
     *
     * @param siblings
     *            non-empty list of children of one node, in ascending code order
     * @return the chosen base offset (0 if an error is set)
     */
    private int insert(List<Node> siblings) {
        if (error_ < 0) {
            return 0;
        }
        int firstCode = siblings.get(0).code;
        int lastCode = siblings.get(siblings.size() - 1).code;

        int begin = 0;
        int pos = Math.max(firstCode + 1, nextCheckPos) - 1;
        int nonzeroNum = 0;
        boolean firstFree = true;

        if (allocSize <= pos) {
            resize(pos + 1);
        }

        outer: while (true) {
            pos++;
            if (allocSize <= pos) {
                resize(pos + 1);
            }
            if (array[pos].check != 0) {
                nonzeroNum++;
                continue;
            } else if (firstFree) {
                nextCheckPos = pos;
                firstFree = false;
            }

            begin = pos - firstCode;
            if (allocSize <= begin + lastCode) {
                // progress can be zero
                double l = Math.max(1.05, 1.0 * keySize / (progress + 1));
                // Codes are char + 1 (up to 65536), so siblings can be far
                // apart; always grow at least enough to hold the last one.
                resize(Math.max((int) (allocSize * l), begin + lastCode + 1));
            }
            if (used[begin] != 0) {
                continue;
            }
            for (int i = 1; i < siblings.size(); i++) {
                if (array[begin + siblings.get(i).code].check != 0) {
                    continue outer;
                }
            }
            break;
        }

        // -- Simple heuristics --
        // If at least 95% of the slots scanned from nextCheckPos were occupied,
        // move nextCheckPos up to pos so later scans skip the dense region.
        if (1.0 * nonzeroNum / (pos - nextCheckPos + 1) >= 0.95) {
            nextCheckPos = pos;
        }
        used[begin] = 1;
        size = Math.max(size, begin + lastCode + 1);

        for (Node sibling : siblings) {
            array[begin + sibling.code].check = begin;
        }

        for (Node sibling : siblings) {
            List<Node> newSiblings = new ArrayList<Node>();
            if (fetch(sibling, newSiblings) == 0) {
                // Leaf: store the value (or the key index) as -(value + 1).
                array[begin + sibling.code].base = (value != null)
                        ? (-value[sibling.left] - 1)
                        : (-sibling.left - 1);
                if (value != null && (-value[sibling.left] - 1) >= 0) {
                    error_ = -2;
                    return 0;
                }
                progress++;
                // if (progress_func_) (*progress_func_) (progress, keySize);
            } else {
                // Compute h first: the recursive insert may replace the array.
                int h = insert(newSiblings);
                array[begin + sibling.code].base = h;
            }
        }
        return begin;
    }

    public DoubleArrayTrie() {
        array = null;
        used = null;
        size = 0;
        allocSize = 0;
        // no_delete_ = false;
        error_ = 0;
    }

    // no deconstructor
    // set_result omitted
    // the search methods returns (the list of) the value(s) instead
    // of (the list of) the pair(s) of value(s) and length(s)
    // set_array omitted
    // array omitted

    /** Drops the double array and resets the size counters. */
    void clear() {
        // if (! no_delete_)
        array = null;
        used = null;
        allocSize = 0;
        size = 0;
        // no_delete_ = false;
    }

    public int getUnitSize() {
        return UNIT_SIZE;
    }

    public int getSize() {
        return size;
    }

    public int getTotalSize() {
        return size * UNIT_SIZE;
    }

    /** Counts occupied units (check != 0) among the first size units. */
    public int getNonzeroSize() {
        int result = 0;
        for (int i = 0; i < size; i++) {
            if (array[i].check != 0) {
                result++;
            }
        }
        return result;
    }

    /** Builds the trie from all of {@code key}; see {@link #build(char[][], int[], int[], int)}. */
    public int build(char key[][], int length[], int value[]) {
        return build(key, length, value, (key == null) ? 0 : key.length);
    }

    /**
     * Builds the trie from scratch, discarding any previous contents.
     *
     * @param keys
     *            keys in ascending char order (otherwise -3 is returned)
     * @param lengths
     *            optional per-key lengths; null means use keys[i].length
     * @param values
     *            optional non-negative value per key (a negative one yields -2);
     *            null means the key's index is stored instead
     * @param keyCount
     *            number of leading keys to insert
     * @return 0 on success or a negative error code. Also returns 0, building
     *         nothing, when keys is null or keyCount exceeds keys.length.
     */
    public int build(char keys[][], int lengths[], int values[], int keyCount) {
        if (keys == null || keyCount > keys.length) {
            return 0;
        }
        // Reset so that a previous build (or failed build) does not leak state.
        clear();
        error_ = 0;

        // progress_func_ = progress_func;
        key = keys;
        length = lengths;
        keySize = keyCount;
        value = values;
        progress = 0;

        resize(INITIAL_ALLOC_SIZE);
        array[0].base = 1;
        nextCheckPos = 0;

        Node root = new Node();
        root.left = 0;
        root.right = keySize;
        root.depth = 0;

        List<Node> siblings = new ArrayList<Node>();
        fetch(root, siblings);
        // With no keys there is nothing to lay out; insert needs at least one sibling.
        if (!siblings.isEmpty()) {
            insert(siblings);
        }

        // size += (1 << 8 * 2) + 1; // ??? 
        // if (size >= allocSize) resize (size);

        used = null;
        return error_;
    }

    /**
     * Loads a double array previously written by {@link #save(String)}. The
     * current contents are replaced only if the whole file is read.
     */
    public void open(String fileName) throws IOException {
        File file = new File(fileName);
        // Divide before narrowing so large files do not overflow the int cast.
        int unitCount = (int) (file.length() / UNIT_SIZE);
        Unit[] loaded = new Unit[unitCount];
        DataInputStream is = new DataInputStream(new BufferedInputStream(
                new FileInputStream(file), BUF_SIZE));
        try {
            for (int i = 0; i < unitCount; i++) {
                Unit unit = new Unit();
                unit.base = is.readInt();
                unit.check = is.readInt();
                loaded[i] = unit;
            }
        } finally {
            is.close();
        }
        array = loaded;
        size = unitCount;
    }

    /** Writes the first size units as (base, check) int pairs. */
    public void save(String fileName) throws IOException {
        DataOutputStream out = new DataOutputStream(new BufferedOutputStream(
                new FileOutputStream(fileName)));
        try {
            for (int i = 0; i < size; i++) {
                out.writeInt(array[i].base);
                out.writeInt(array[i].check);
            }
        } finally {
            out.close();
        }
    }

    /** Returns the value stored for exactly {@code key[pos..]}, or -1. */
    public int exactMatchSearch(char key[], int pos) {
        return exactMatchSearch(key, pos, 0, 0);
    }

    /**
     * Looks up key[pos .. len) starting from the node at nodePos.
     *
     * @param len
     *            end index (exclusive); values &lt;= 0 mean key.length
     * @param nodePos
     *            starting node; values &lt;= 0 mean the root
     * @return the stored value, or -1 if there is no exact match (including
     *         characters that fall outside the array)
     */
    public int exactMatchSearch(char key[], int pos, int len, int nodePos) {
        if (len <= 0) {
            len = key.length;
        }
        if (nodePos <= 0) {
            nodePos = 0;
        }
        if (array == null || nodePos >= array.length) {
            return -1;
        }

        int b = array[nodePos].base;
        for (int i = pos; i < len; i++) {
            int p = b + key[i] + 1;
            if (!isChild(p, b)) {
                return -1;
            }
            b = array[p].base;
        }
        return terminalValue(b);
    }

    /** Common-prefix search over key[pos..]; see the 6-argument overload. */
    public int commonPrefixSearch(char key[], int result[], int resultLen,
            int pos) {
        return commonPrefixSearch(key, result, resultLen, pos, 0, 0);
    }

    /**
     * Finds every stored key that is a prefix of key[pos .. len).
     *
     * @param result
     *            receives the values of the matches, shortest first; at most
     *            resultLen are written
     * @param len
     *            end index (exclusive); values &lt;= 0 mean key.length
     * @param nodePos
     *            starting node; values &lt;= 0 mean the root
     * @return the total number of matches, which may exceed resultLen
     */
    public int commonPrefixSearch(char key[], int result[], int resultLen,
            int pos, int len, int nodePos) {
        if (len <= 0) {
            len = key.length;
        }
        if (nodePos <= 0) {
            nodePos = 0;
        }
        if (array == null || nodePos >= array.length) {
            return 0;
        }

        int b = array[nodePos].base;
        int num = 0;
        for (int i = pos; i < len; i++) {
            int found = terminalValue(b);
            if (found >= 0) {
                if (num < resultLen) {
                    result[num] = found;
                }
                num++;
            }
            int p = b + key[i] + 1;
            if (!isChild(p, b)) {
                return num;
            }
            b = array[p].base;
        }
        int last = terminalValue(b);
        if (last >= 0) {
            if (num < resultLen) {
                result[num] = last;
            }
            num++;
        }
        return num;
    }

    /** Returns true if slot p exists and belongs to the node whose base is parentBase. */
    private boolean isChild(int p, int parentBase) {
        return p >= 0 && p < array.length && array[p].check == parentBase;
    }

    /** Returns the value at the end-of-key slot of the node with base b, or -1. */
    private int terminalValue(int b) {
        if (isChild(b, b) && array[b].base < 0) {
            return -array[b].base - 1;
        }
        return -1;
    }

    // debug
    public void dump() {
        for (int i = 0; i < size; i++) {
            System.err.println("i: " + i + " [" + array[i].base + ", "
                    + array[i].check + "]");
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import static org.junit.Assert.*; | |
import java.util.Arrays; | |
import java.util.List; | |
import org.junit.Test; | |
public class DartsMapTest { | |
private static final String[] KEYS = { "ALGOL", "ANSI", "ARCO", "ARPA", | |
"ARPANET", "ASCII" }; | |
private static final String[] VALUES; | |
static { | |
String[] values = new String[KEYS.length]; | |
for (int i = 0; i < KEYS.length; i++) { | |
values[i] = KEYS[i].toLowerCase(); | |
} | |
VALUES = values; | |
} | |
@Test | |
public void キーとして存在しない値を指定してgetしてみる() { | |
DartsMap<String> map = DartsMap.<String> create(Arrays.asList(KEYS), | |
Arrays.asList(VALUES)); | |
assertNull(map.get("APPARE")); | |
} | |
@Test | |
public void キーとして存在する値を指定してgetしてみる() { | |
DartsMap<String> map = DartsMap.<String> create(Arrays.asList(KEYS), | |
Arrays.asList(VALUES)); | |
assertEquals("arpa", map.get("ARPA")); | |
assertEquals("arpanet", map.get("ARPANET")); | |
assertEquals("ascii", map.get("ASCII")); | |
assertEquals("algol", map.get("ALGOL")); | |
assertEquals("ansi", map.get("ANSI")); | |
assertEquals("arco", map.get("ARCO")); | |
} | |
@Test | |
public void 接頭辞として存在しない文字列で共通接頭辞検索してみる() { | |
DartsMap<String> map = DartsMap.<String> create(Arrays.asList(KEYS), | |
Arrays.asList(VALUES)); | |
List<String> result = map.commonPrefixSearch("APPARE"); | |
assertEquals(0, result.size()); | |
} | |
@Test | |
public void 接頭辞として1つ存在する文字列で共通接頭辞検索してみる() { | |
DartsMap<String> map = DartsMap.<String> create(Arrays.asList(KEYS), | |
Arrays.asList(VALUES)); | |
List<String> result = map.commonPrefixSearch("ALGOLOGIC"); | |
assertEquals(1, result.size()); | |
assertEquals("algol", result.get(0)); | |
} | |
@Test | |
public void 接頭辞として2つ存在する文字列で共通接頭辞検索してみる() { | |
DartsMap<String> map = DartsMap.<String> create(Arrays.asList(KEYS), | |
Arrays.asList(VALUES)); | |
List<String> result = map.commonPrefixSearch("ARPANET, Internet"); | |
assertEquals(2, result.size()); | |
assertTrue(result.contains("arpa")); | |
assertTrue(result.contains("arpanet")); | |
} | |
@Test | |
public void キーに合致する文字列で共通接頭辞検索してみる() { | |
DartsMap<String> map = DartsMap.<String> create(Arrays.asList(KEYS), | |
Arrays.asList(VALUES)); | |
List<String> result = map.commonPrefixSearch("ASCII"); | |
assertEquals(1, result.size()); | |
assertEquals("ascii", result.get(0)); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment