Last active
April 20, 2016 11:09
-
-
Save jcjohnson/61d23297d6ee67b065e5 to your computer and use it in GitHub Desktop.
Simple torch benchmarking tool for fully-connected networks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'nn' | |
require 'cutorch' | |
require 'cunn' | |
--[[
-- A simple benchmark comparing fully-connected net times on CPU and GPU.
--
-- Note that we don't count the time it takes to transfer data to the GPU.
--]]
-- Command-line configuration for the benchmark.
-- Every option carries a help string so that `-help` prints a usable summary,
-- and obviously invalid values are rejected before any model is built.
local cmd = torch.CmdLine()
cmd:text('Benchmark fully-connected network forward/backward time on CPU and GPU.')
cmd:text()
cmd:text('Options:')
cmd:option('-input_dim', 100, 'Number of input features per example')
cmd:option('-output_dim', 4, 'Number of output values per example')
cmd:option('-hidden_dim', 4096, 'Width of every hidden layer')
cmd:option('-hidden_layers', 5, 'Number of hidden Linear+ReLU blocks')
cmd:option('-batch_size', 1000, 'Number of examples per forward/backward pass')
cmd:option('-num_trials', 5, 'Number of timed passes per tensor type')
cmd:option('-quiet', false, 'Do not print the time of each individual trial')
cmd:option('-gpu', 0, 'Zero-based index of the GPU to benchmark')
local opt = cmd:parse(arg)

-- The per-dtype mean is taken over num_trials timings, so at least one
-- trial is required for the result to be meaningful.
if opt.num_trials < 1 then
  error(string.format('-num_trials must be at least 1 (got %s)',
                      tostring(opt.num_trials)))
end

-- Fail early with a readable message instead of passing a bad device index
-- to cutorch.
local device_count = cutorch.getDeviceCount()
if opt.gpu < 0 or opt.gpu >= device_count then
  error(string.format('-gpu must be in [0, %d] (got %s)',
                      device_count - 1, tostring(opt.gpu)))
end

-- -gpu is zero-based on the command line; cutorch device ids start at 1.
cutorch.setDevice(opt.gpu + 1)
-- Assemble the network under test and its loss.
-- Layout: Linear(input_dim -> hidden_dim), then `hidden_layers` repetitions
-- of [Linear(hidden_dim -> hidden_dim), in-place ReLU], then
-- Linear(hidden_dim -> output_dim). The loss is mean squared error.
-- NOTE(review): there is no nonlinearity after the first Linear layer;
-- presumably irrelevant for a timing benchmark -- confirm this is intended.
local model = nn.Sequential()
local crit = nn.MSECriterion()
do
  local in_width, hidden_width, out_width =
    opt.input_dim, opt.hidden_dim, opt.output_dim

  -- Input projection.
  model:add(nn.Linear(in_width, hidden_width))

  -- Hidden blocks, all of the same width.
  for _ = 1, opt.hidden_layers do
    model:add(nn.Linear(hidden_width, hidden_width))
    model:add(nn.ReLU(true))
  end

  -- Output projection.
  model:add(nn.Linear(hidden_width, out_width))
end
-- Time full forward + backward passes of `model`/`crit` for each tensor type
-- and report the mean time per type plus the resulting CPU/GPU ratio.
--
-- For every dtype one untimed warm-up pass is run first, so that first-use
-- costs (module buffers being allocated on the first forward/backward, and
-- any one-time device setup) are not attributed to the first timed trial.
-- Input batches are created and converted to the target dtype before the
-- timer is reset, so data conversion/transfer is never part of a timing.
local dtypes = {'torch.FloatTensor', 'torch.CudaTensor'}
local mean_times = {}
local timer = torch.Timer()
local num_warmup = 1

-- Returns a fresh random (inputs, targets) pair already converted to `dtype`.
local function make_batch(dtype)
  local X = torch.randn(opt.batch_size, opt.input_dim):type(dtype)
  local y = torch.randn(opt.batch_size, opt.output_dim):type(dtype)
  return X, y
end

-- Runs one complete forward and backward pass through model and criterion.
local function run_pass(X, y)
  local y_pred = model:forward(X)
  crit:forward(y_pred, y)
  local dy_pred = crit:backward(y_pred, y)
  model:backward(X, dy_pred)
end

for _, dtype in ipairs(dtypes) do
  print(string.format('Testing dtype %s', dtype))
  model:type(dtype)
  crit:type(dtype)

  -- Untimed warm-up passes.
  for _ = 1, num_warmup do
    run_pass(make_batch(dtype))
  end

  local times = torch.DoubleTensor(opt.num_trials)
  for t = 1, opt.num_trials do
    local X, y = make_batch(dtype)

    -- Drain any pending GPU work so the timer starts from an idle device.
    cutorch.synchronize()
    timer:reset()
    run_pass(X, y)
    -- GPU kernels run asynchronously; wait for them before reading the clock.
    cutorch.synchronize()

    local time = timer:time().real
    times[t] = time
    if not opt.quiet then
      print(time)
    end
  end

  local mean_time = times:mean()
  table.insert(mean_times, mean_time)
  print(string.format('Mean time: %f', mean_time))
end

-- mean_times follows the order of `dtypes`: [1] is CPU, [2] is GPU.
local cpu_mean, gpu_mean = mean_times[1], mean_times[2]
if gpu_mean > 0 then
  print(string.format('GPU speedup: %f', cpu_mean / gpu_mean))
else
  -- Timer resolution too coarse to measure the GPU pass; avoid dividing by 0.
  print('GPU speedup: undefined (measured GPU time was zero)')
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment