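--[[
Trains a convolutional network on the TinyImageNetA dataset with Torch and
exports the learned parameters both as a .t7 checkpoint and as an HDF5 file
(W*/b*, gamma*/beta*, running_mean*/running_var*) that can be read back from
the accompanying Python code. Example invocation (assuming the script is
saved under the hypothetical name train.lua):

  th train.lua -data_h5_file cv/TinyImageNetA.h5 \
    -output_h5_file cv/torch_model.h5 -output_t7_file cv/torch_model.t7
--]]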
require 'torch'
require 'nn'
require 'cutorch'
require 'cunn'
require 'cudnn'
require 'optim'
require 'hdf5'
require 'image'

local cmd = torch.CmdLine()
cmd:option('-data_h5_file', 'cv/TinyImageNetA.h5')
cmd:option('-output_h5_file', 'cv/torch_model.h5')
cmd:option('-output_t7_file', 'cv/torch_model.t7')
local opt = cmd:parse(arg)
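
-- Load the training and validation splits from a single HDF5 file containing
-- the datasets /X_train, /y_train, /X_val, and /y_val. Labels are shifted by
-- +1 to convert 0-indexed class labels to the 1-indexed labels Torch expects.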
local function load_data(data_file)
  print(data_file)
  local f = hdf5.open(data_file)
  local dset = {}
  dset.X_train = f:read('/X_train'):all()
  dset.y_train = f:read('/y_train'):all() + 1
  dset.X_val = f:read('/X_val'):all()
  dset.y_val = f:read('/y_val'):all() + 1
  f:close()
  return dset
end
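
-- Sample a random minibatch (with replacement) from the full arrays X and y.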
local function get_minibatch(X, y, batch_size)
  local mask = torch.LongTensor(batch_size):random(X:size(1))
  local X_batch = X:index(1, mask)
  local y_batch = y:index(1, mask)
  return X_batch, y_batch
end
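
-- Data augmentation: take a random (H - 16) x (W - 16) crop of the batch and
-- flip the whole batch horizontally with probability 1/2. Currently unused;
-- the training closure below calls random_flip instead.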
function random_crop_flip(batch)
  local crop_size = 16
  local H, W = batch:size(3), batch:size(4)
  local h, w = H - crop_size, W - crop_size
  local x0 = torch.random(1 + crop_size)
  local y0 = torch.random(1 + crop_size)
  local cropped = batch[{{}, {}, {y0, y0 + h - 1}, {x0, x0 + w - 1}}]
  if torch.random(2) == 1 then
    for i = 1, cropped:size(1) do
      cropped[i] = image.hflip(cropped[i])
    end
  end
  return cropped
end
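
-- Data augmentation: flip each image in the batch horizontally with
-- probability 1/2, leaving the spatial size unchanged.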
function random_flip(batch)
  local flipped = batch:clone()
  for i = 1, batch:size(1) do
    if torch.random(2) == 1 then
      flipped[i] = image.hflip(flipped[i])
    end
  end
  return flipped
end
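
-- Deterministic (H - 16) x (W - 16) center crop, the evaluation counterpart
-- of random_crop_flip; only needed if crop-based training is re-enabled.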
function center_crop(batch)
  local crop_size = 16
  local H, W = batch:size(3), batch:size(4)
  local h, w = H - crop_size, W - crop_size
  local x0, y0 = crop_size / 2, crop_size / 2
  local cropped = batch[{{}, {}, {y0, y0 + h - 1}, {x0, x0 + w - 1}}]
  return cropped:clone()
end
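
-- Estimate classification accuracy by averaging over 20 random minibatches.
-- Puts the model in evaluate mode; the caller is responsible for switching
-- back to training mode afterwards.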
local function check_accuracy(X, y, model, batch_size)
  model:evaluate()
  local num_correct = 0
  local num_tested = 0
  for t = 1, 20 do
    local X_batch, y_batch = get_minibatch(X, y, batch_size)
    -- X_batch = center_crop(X_batch)
    X_batch = X_batch:cuda()
    y_batch = y_batch:cuda()
    local scores = model:forward(X_batch)
    local _, y_pred = scores:max(2)
    num_correct = num_correct + torch.eq(y_pred, y_batch):sum()
    num_tested = num_tested + batch_size
  end
  return num_correct / num_tested
end
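
-- Build a VGG-style network: a stack of conv / batchnorm / ReLU / dropout
-- blocks whose spatial resolution is halved at every stride-2 convolution
-- (64 -> 32 -> 16 -> 8 -> 4 -> 2), followed by a fully connected hidden
-- layer and a linear classifier over num_classes classes.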
local function build_model()
  --[[
  -- This is what the Python code expects now
  local num_filters = {64, 64, 128, 128, 256, 256, 512}
  local filter_sizes = {5, 3, 3, 3, 3, 3, 3}
  local filter_strides = {2, 1, 2, 1, 2, 1, 2}
  local num_classes = 100
  local hidden_dim = 1024
  local image_size = 64 - 16
  --]]
  local num_filters = {64, 64, 128, 128, 256, 256, 512, 512, 1024}
  local filter_sizes = {5, 3, 3, 3, 3, 3, 3, 3, 3}
  local filter_strides = {2, 1, 2, 1, 2, 1, 2, 1, 2}
  local dropout = {0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.5}
  local num_classes = 100
  local hidden_dim = 512
  local image_size = 64
  local prev_dim = 3
  local cur_size = image_size
  local model = nn.Sequential()
  for i = 1, #num_filters do
    local next_dim = num_filters[i]
    local size = filter_sizes[i]
    local stride = filter_strides[i]
    local pad = (size - 1) / 2
    model:add(nn.SpatialConvolution(prev_dim, next_dim,
                                    size, size, stride, stride, pad, pad))
    model:add(nn.SpatialBatchNormalization(next_dim))
    model:add(nn.ReLU(true))
    model:add(nn.Dropout(dropout[i]))
    prev_dim = next_dim
    if stride == 2 then
      cur_size = cur_size / 2
    end
  end
  local fan_in = cur_size * cur_size * num_filters[#num_filters]
  model:add(nn.View(-1):setNumInputDims(3))
  model:add(nn.Linear(fan_in, hidden_dim))
  model:add(nn.BatchNormalization(hidden_dim))
  model:add(nn.Dropout(0.8))
  model:add(nn.ReLU(true))
  model:add(nn.Linear(hidden_dim, num_classes))
  return model
end
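
-- Dump learnable parameters and batchnorm statistics to HDF5. Convolutional
-- and linear layers are written as /W%d and /b%d; batchnorm layers as
-- /gamma%d, /beta%d, /running_mean%d, and /running_var%d. For the fully
-- connected nn.BatchNormalization, which in this version of nn stores
-- running_std = 1 / sqrt(var + eps), the variance is reconstructed as
-- running_std^-2 - eps.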
local function save_model(model, out_file)
  local next_weight_idx = 1
  local next_bn_idx = 1
  local f = hdf5.open(out_file, 'w')
  for i = 1, #model do
    local layer = model:get(i)
    if torch.isTypeOf(layer, nn.SpatialConvolution) or
       torch.isTypeOf(layer, nn.Linear) then
      f:write(string.format('/W%d', next_weight_idx), layer.weight:float())
      f:write(string.format('/b%d', next_weight_idx), layer.bias:float())
      next_weight_idx = next_weight_idx + 1
    elseif torch.isTypeOf(layer, nn.SpatialBatchNormalization) or
           torch.isTypeOf(layer, nn.BatchNormalization) then
      f:write(string.format('/gamma%d', next_bn_idx), layer.weight:float())
      f:write(string.format('/beta%d', next_bn_idx), layer.bias:float())
      f:write(string.format('/running_mean%d', next_bn_idx), layer.running_mean:float())
      if torch.isTypeOf(layer, nn.BatchNormalization) then
        f:write(string.format('/running_var%d', next_bn_idx),
                torch.pow(layer.running_std, -2.0):add(-layer.eps):float())
      elseif torch.isTypeOf(layer, nn.SpatialBatchNormalization) then
        f:write(string.format('/running_var%d', next_bn_idx),
                layer.running_var:float())
      end
      next_bn_idx = next_bn_idx + 1
    end
  end
  f:close()
end
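
-- Main script: load the data, build the model, convert it to cuDNN layers,
-- and move everything to the GPU before training.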
local dset = load_data(opt.data_h5_file)

local model = build_model()
print(model)
cudnn.convert(model, cudnn)
model:cuda()
model:training()

local crit = nn.CrossEntropyCriterion():cuda()

local num_iterations = 120000
local reg = 1e-3
local batch_size = 50
local config = {
  learningRate = 1e-1,
}

local t = 0
local params, gradParams = model:getParameters()
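
-- optim-style closure: returns the data loss and the flattened gradient for
-- one randomly sampled, horizontally flipped minibatch. L2 regularization is
-- applied directly to the gradient (reg * params), so the printed loss is
-- the data loss only.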
local function f(w)
  gradParams:zero()
  local X_batch, y_batch = get_minibatch(dset.X_train, dset.y_train, batch_size)
  -- X_batch = random_crop_flip(X_batch):cuda()
  X_batch = random_flip(X_batch):cuda()
  y_batch = y_batch:cuda()
  assert(w == params)
  local scores = model:forward(X_batch)
  local data_loss = crit:forward(scores, y_batch)
  local dscores = crit:backward(scores, y_batch)
  model:backward(X_batch, dscores)
  -- add regularization
  gradParams:add(reg, params)
  if t % 100 == 0 then
    print(t, data_loss, torch.abs(gradParams):mean())
  end
  return data_loss, gradParams
end
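
-- Train with SGD, decaying the learning rate by 1.5x every 7500 iterations.
-- Training and validation accuracy are estimated every 200 iterations on
-- random minibatches; model:training() restores training mode afterwards,
-- since check_accuracy switches the model to evaluate mode.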
while t < num_iterations do
  t = t + 1
  -- optim.adam(f, params, config)
  optim.sgd(f, params, config)

  -- Check training and validation accuracy once in a while
  if t % 200 == 0 then
    local train_acc = check_accuracy(dset.X_train, dset.y_train, model, batch_size)
    local val_acc = check_accuracy(dset.X_val, dset.y_val, model, batch_size)
    print('train acc: ', train_acc, 'val_acc: ', val_acc)
    model:training()
  end

  if t % 7500 == 0 then
    config.learningRate = config.learningRate / 1.5
  end
  --[[
  -- This schedule works well for adam, starting from 1e-3
  if t % 4000 == 0 then
    config.learningRate = config.learningRate / 2.0
  end
  --]]
end
save_model(model, opt.output_h5_file)
torch.save(opt.output_t7_file, model)
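
--[[
A minimal sketch of reloading the serialized checkpoint later from Torch
(the path here matches the default -output_t7_file option):

  require 'torch'
  require 'nn'
  require 'cutorch'
  require 'cunn'
  require 'cudnn'
  local model = torch.load('cv/torch_model.t7')
  model:evaluate()
--]]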