Last active
December 31, 2015 00:19
-
-
Save huichen/7906138 to your computer and use it in GitHub Desktop.
一百行实现一个分类器
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"flag" | |
"github.com/huichen/mlf/contrib" | |
"github.com/huichen/mlf/eval" | |
"github.com/huichen/mlf/model" | |
"github.com/huichen/mlf/optimizer" | |
"log" | |
"os" | |
"runtime/pprof" | |
) | |
var ( | |
// 数据输入 | |
libsvm_file = flag.String("input", "", "libsvm格式的数据文件,训练数据") | |
test_file = flag.String("test", "", "libsvm格式的数据文件,测试数据") | |
// 模型输出 | |
model_file = flag.String("output", "model.mlf", "模型输出") | |
// 机器学习参数 | |
opt = flag.String("optimizer", "lbfgs", "优化器") | |
reg = flag.Int("regularization", 2, "正则化方法") | |
reg_factor = flag.Float64("reg_factor", float64(1), "正则化因子") | |
learning_rate = flag.Float64("learning_rate", float64(1), "学习率") | |
characteristic_time = flag.Float64("characteristic_time", float64(0), "学习率特征时间") | |
batch_size = flag.Int("batch_size", 0, | |
"梯度递降法的batch尺寸: 0为full batch, 1为stochastic, 其它值为mini batch") | |
delta = flag.Float64("delta", 1e-5, | |
"权重变化量和权重的比值(|dw|/|w|)小于此值时判定为收敛") | |
max_iter = flag.Int("max_iter", 0, "优化器最多迭代多少次") | |
folds = flag.Int("folds", 0, "N-交叉评价,值为零时不交叉评价") | |
// 性能测试输出 | |
cpuprofile = flag.String("cpuprofile", "", "处理器profile文件") | |
) | |
func main() { | |
flag.Parse() | |
// 载入训练集 | |
set := contrib.LoadLibSVMDataset(*libsvm_file) | |
// 设置训练器参数 | |
trainerOptions := model.TrainerOptions{ | |
Optimizer: optimizer.OptimizerOptions{ | |
OptimizerName: *opt, | |
RegularizationScheme: *reg, | |
RegularizationFactor: *reg_factor, | |
LearningRate: *learning_rate, | |
CharacteristicTime: *characteristic_time, | |
ConvergingDeltaWeight: *delta, | |
MaxIterations: *max_iter, | |
Options: &optimizer.GdOptions{ | |
BatchSize: *batch_size, | |
}}} | |
// 创建训练器 | |
trainer := model.NewMaxEntClassifierTrainer(trainerOptions) | |
// 打开处理器profile文件 | |
if *cpuprofile != "" { | |
f, err := os.Create(*cpuprofile) | |
if err != nil { | |
log.Fatal(err) | |
} | |
pprof.StartCPUProfile(f) | |
defer pprof.StopCPUProfile() | |
} | |
// 进行交叉评价 | |
evaluators := eval.NewEvaluators([]eval.Evaluator{ | |
&eval.PREvaluator{}, &eval.AccuracyEvaluator{}}) | |
if *folds != 0 { | |
result := eval.CrossValidate(trainer, set, evaluators, *folds) | |
log.Print(*folds, "-folds 交叉评价:") | |
log.Print("F1 = ", result.Metrics["fscore"]) | |
log.Print("准确度 = ", result.Metrics["accuracy"]) | |
return | |
} | |
// 在全部数据上训练模型 | |
model := trainer.Train(set) | |
model.Write(*model_file) | |
// 测试模型 | |
if *test_file != "" { | |
// 载入测试集 | |
testSet := contrib.LoadLibSVMDataset(*test_file) | |
// 在测试集上评价模型并输出结果 | |
result := evaluators.Evaluate(model, testSet) | |
log.Print("测试数据集评价:") | |
log.Print("精度 = ", result.Metrics["precision"]) | |
log.Print("召回率 = ", result.Metrics["recall"]) | |
log.Print("F1 = ", result.Metrics["fscore"]) | |
log.Print("准确度 = ", result.Metrics["accuracy"]) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment