Last active
November 30, 2016 01:57
-
-
Save innerlee/67f588332666abbab1601c0f0f848cf8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using HDF5
using JLD
using LIBLINEAR

println("> svm on open image")

# Read the pooled features, taking index 1 on the first two axes of the stored
# array so that a matrix with one column per sample remains.
# Expected size: (2048,165659).
features = h5read("data/grand5_feature.h5", "global_pool")[1, 1, :, :]

# Sample ids: the first column of the validation list. Column i of `features`
# is paired with ids[i] further down, so both must have the same length.
# NOTE(review): readdlm turns numeric-looking fields into numbers, so an id made
# only of digits would not stay a string; presumably harmless because the label
# file is read the same way -- confirm.
ids = readdlm("data/redis_val_list_with_id.txt")[:, 1]
size(features, 2) == length(ids) ||
    error("feature/id mismatch: $(size(features, 2)) feature columns, $(length(ids)) ids")

# Label rows with the header line dropped, keeping columns 1, 3 and 4, e.g.
# "000026e7ee790996" "/m/01cbzq" 1.0
raw_labels = readdlm("data/labels.csv", ',')[2:end, [1, 3, 4]]
# id => a single positive label; stays "unknown" until a positive label is seen
label_dict = Dict()
# id => every positive label recorded for that id, in file order
label_dict_multiple = Dict()

# Remove too-frequent labels, thresholded by quantile.
# ARGS[2] picks the quantile (0, 0.9, 0.99); when it is absent or not positive,
# no label is removed.
remove_frequent = length(ARGS) < 2 ? 0 : parse(Float64, ARGS[2])
# Always a Set, so the membership test in the loop is a hash lookup whether or
# not a list was loaded.
bad = Set()
if remove_frequent > 0
    bad = Set(readdlm("data/bad$remove_frequent.txt")[:])
end
println("remove frequent: ", remove_frequent)

for i = 1:size(raw_labels, 1)
    # Read the three fields directly instead of slicing out a temporary row.
    id = raw_labels[i, 1]
    label = raw_labels[i, 2]
    score = raw_labels[i, 3]
    label in bad && continue
    # First sighting of an id: register it with no positive label yet.
    if !haskey(label_dict, id)
        label_dict[id] = "unknown"
        label_dict_multiple[id] = []
    end
    # A later positive label replaces the earlier one in label_dict, while
    # label_dict_multiple accumulates all of them.
    if score > 0
        label_dict[id] = label
        push!(label_dict_multiple[id], label)
    end
end
# Ground truth, aligned with `ids`:
#   gt   - one label per id ("unknown" when the id has no positive label)
#   gt_m - all positive labels per id (an empty list when there are none)
# remove_frequent = 0, 2881 of them are "unknown"
# remove_frequent = 0.99, 9565 of them are "unknown"
gt = map(x -> get(label_dict, x, "unknown"), ids)
# Default to an empty list (not a string) so every entry of gt_m is a list.
gt_m = map(x -> get(label_dict_multiple, x, []), ids)

# Keep only the samples that have a known label.
fil = gt .!= "unknown"
# 2048×162778 when remove_frequent = 0 (165659 - 2881 samples), which matches
# the 81389-sample test half recorded in the results at the end of the script.
features = features[:, fil]
gt = gt[fil]
gt_m = gt_m[fil]
# Shuffle samples and labels with one shared permutation so they stay aligned.
# NOTE(review): the RNG is not seeded here, so the split (and therefore the
# reported numbers) differ from run to run.
perm = shuffle(1:length(gt))
features = features[:, perm]
gt = gt[perm]
gt_m = gt_m[perm]

# Split train/test 50/50; with an odd sample count the extra one goes to test.
num_train = div(length(gt), 2)
train_range = 1:num_train
test_range = (num_train + 1):length(gt)

train_features = features[:, train_range]
train_gt = gt[train_range]
test_features = features[:, test_range]
test_gt = gt[test_range]
test_gt_m = gt_m[test_range]
## parse args
# ARGS[1] (optional): maximum number of training samples, given as an integer
# literal. Without it the whole training half is used.
# The value is parsed as an Int rather than evaluated as code, so arbitrary
# command-line text is never executed, and ARGS itself is left unmodified.
if isempty(ARGS)
    train_num = length(train_gt)
else
    train_num = parse(Int, ARGS[1])
    train_num > 0 ||
        error("ARGS[1] (training size) must be a positive integer, got $(ARGS[1])")
end

# Never ask for more samples than the training half contains.
num_kept = min(train_num, length(train_gt))
train_features = train_features[:, 1:num_kept]
train_gt = train_gt[1:num_kept]
println("train size: ", size(train_features))
println("test size: ", size(test_features))
## train & test
# Time training plus prediction and report it as "elapsed time: <s> seconds".
started = time()
m = linear_train(train_gt, train_features)
(pred, scores) = linear_predict(m, test_features)
println("elapsed time: ", time() - started, " seconds")

# Column i of `scores` holds the scores of test sample i; row indices are
# translated to labels through m.labels. Keep the (up to) five best-scoring
# labels per sample -- fewer when the model has fewer than five rows of scores.
topk = min(5, size(scores, 1))
top5list = [m.labels[sortperm(scores[:, i], rev=true)[1:topk]] for i = 1:size(scores, 2)]

num_test = length(test_gt)

# Accuracy against the single label kept per sample.
top1_hits = vec(pred) .== vec(test_gt)
top5_hits = [test_gt[i] in top5list[i] for i = 1:num_test]

# Tag recall against all positive labels of a sample: a hit means at least one
# of them was predicted. The top-5 variant tests membership directly, so a
# label listed twice for a sample cannot be mistaken for a match.
top1_tag_hits = [pred[i] in test_gt_m[i] for i = 1:num_test]
top5_tag_hits = [any(l -> l in top5list[i], test_gt_m[i]) for i = 1:num_test]

println("top-1 accuracy: ", mean(top1_hits))
println("top-5 accuracy: ", mean(top5_hits))
println("top-1 tag recall: ", mean(top1_tag_hits))
println("top-5 tag recall: ", mean(top5_tag_hits))
## results
# remove frequent: 0
# train size: (2048,80000)
# test size: (2048,81389)
# elapsed time: 7735.106875664 seconds
# top-1 accuracy: 0.49107373232254975
# top-5 accuracy: 0.7078966445097004
# top-1 tag recall: 0.6163609332956542
# top-5 tag recall: 0.8360835002273035
# remove frequent: 0.99
# train size: (2048,78047)
# test size: (2048,78047)
# elapsed time: 10830.214292915 seconds
# top-1 accuracy: 0.3900854613245865
# top-5 accuracy: 0.6232270298666188
# top-1 tag recall: 0.5067843735185209
# top-5 tag recall: 0.7632195984470895
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment