Skip to content

Instantly share code, notes, and snippets.

@tasukujp
Last active October 17, 2015 15:47
Show Gist options
  • Save tasukujp/2202fa8fc111329d6dd7 to your computer and use it in GitHub Desktop.
Save tasukujp/2202fa8fc111329d6dd7 to your computer and use it in GitHub Desktop.
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XGBoostError;
import java.util.*;
/**
 * Minimal XGBoost4J example: trains a {@code multi:softmax} model on the iris
 * training file and reports the classification error on the iris test file.
 */
public class XgboostSample {

    /**
     * Loads the training/test data, trains a booster for a fixed number of rounds,
     * predicts on the test data and prints the prediction shape and error rate.
     *
     * @param args unused
     * @throws XGBoostError if loading data, training or prediction fails
     */
    public static void main(String[] args) throws XGBoostError {
        // Training data
        DMatrix trainMat = new DMatrix("/tmp/iris.train.scale");
        // Test data
        DMatrix testMat = new DMatrix("/tmp/iris.test.scale");

        // Training parameters. A plain HashMap avoids the double-brace idiom, which
        // would create an anonymous HashMap subclass just to run a few put() calls.
        Map<String, Object> param = new HashMap<>();
        param.put("objective", "multi:softmax");
        param.put("num_class", 3);

        // Named data sets handed to Trainer.train as its watch list.
        List<Map.Entry<String, DMatrix>> watches = new ArrayList<>();
        watches.add(new AbstractMap.SimpleEntry<>("train", trainMat));
        watches.add(new AbstractMap.SimpleEntry<>("test", testMat));

        // Build the model
        int round = 10;
        Booster booster = Trainer.train(param.entrySet(), trainMat, round, watches, null, null);

        // Run prediction
        float[][] predicts = booster.predict(testMat);
        System.out.println("predict length1: " + predicts.length);
        // predicts[0] does not exist when the test set produced no rows.
        int width = predicts.length > 0 ? predicts[0].length : 0;
        System.out.println("predict length2: " + width);
        System.out.println("error of predicts: " + eval(predicts, testMat));
    }

    /**
     * Computes the fraction of rows whose prediction differs from the true label.
     *
     * <p>Only the first column of each prediction row is compared. This assumes a
     * {@code multi:softmax} model, where each row holds a single predicted class.
     *
     * @param predicts prediction matrix returned by {@code Booster.predict}
     * @param dmat     data set holding the true labels
     * @return error rate in [0, 1]; {@code Float.NaN} when there are no rows;
     *         {@code -1f} when the labels cannot be read
     * @throws IllegalArgumentException if the number of prediction rows and labels differ
     */
    private static float eval(float[][] predicts, DMatrix dmat) {
        float[] labels;
        try {
            labels = dmat.getLabel();
        } catch (XGBoostError ex) {
            // Keep the -1 sentinel for callers, but do not hide why it happened.
            System.err.println("failed to read labels: " + ex);
            return -1f;
        }

        int nrow = predicts.length;
        if (nrow != labels.length) {
            // Counting over one length and dividing by the other would either overrun
            // the labels array or silently report a wrong error rate.
            throw new IllegalArgumentException(
                    "prediction rows (" + nrow + ") != label count (" + labels.length + ")");
        }
        if (nrow == 0) {
            return Float.NaN; // error rate is undefined for an empty data set
        }

        int errors = 0;
        for (int i = 0; i < nrow; i++) {
            // Exact comparison is intended: both sides are class indices stored as floats.
            if (labels[i] != predicts[i][0]) {
                errors++;
            }
        }
        System.out.println("error of count: " + errors);
        return (float) errors / nrow;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment