-- Source: GitHub gist jcjohnson/04e649e285dbf07690db (created October 11, 2015 23:52).
-- Note: GitHub flagged this file as possibly containing bidirectional Unicode
-- text that may be interpreted or compiled differently than it appears; review
-- it in an editor that reveals hidden Unicode characters.
require 'torch'
require 'cutorch'
require 'nn'
require 'cunn'
require 'cudnn'
require 'loadcaffe'

-- Command-line flags. Each option carries help text so that `-h` documents
-- the accepted values instead of listing bare names and defaults.
local cmd = torch.CmdLine()
cmd:text()
cmd:text('Benchmark forward/backward timing of a Caffe model loaded with loadcaffe.')
cmd:text()
cmd:text('Options:')
cmd:option('-model', 'alexnet', 'model to benchmark: alexnet | caffenet | vgg-16 | vgg-19')
cmd:option('-backend', 'nn', 'loadcaffe backend used to build the network, e.g. nn or cudnn')
cmd:option('-gpu', 0, '0-indexed id of the GPU to run on')
cmd:option('-batch_size', 10, 'number of random input images per batch')
local params = cmd:parse(arg)
-- Supported models: for each -model name, the deploy prototxt, the pretrained
-- caffemodel weights, and the square input resolution (used below as both the
-- height and width of the random input batch).
local MODEL_SPECS = {
  ['alexnet'] = {
    proto = 'models/bvlc_alexnet/deploy.prototxt',
    weights = 'models/bvlc_alexnet/bvlc_alexnet.caffemodel',
    resolution = 227,
  },
  ['caffenet'] = {
    proto = 'models/bvlc_reference_caffenet/deploy.prototxt',
    weights = 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel',
    resolution = 227,
  },
  ['vgg-16'] = {
    proto = 'models/vgg-16/VGG_ILSVRC_16_layers_deploy.prototxt',
    weights = 'models/vgg-16/VGG_ILSVRC_16_layers.caffemodel',
    resolution = 224,
  },
  ['vgg-19'] = {
    proto = 'models/vgg-19/VGG_ILSVRC_19_layers_deploy.prototxt',
    weights = 'models/vgg-19/VGG_ILSVRC_19_layers.caffemodel',
    resolution = 224,
  },
}

local model_file, proto_file, size
local spec = MODEL_SPECS[params.model]
if spec == nil then
  error(string.format('Unrecognized model "%s"', params.model))
end
proto_file = spec.proto
model_file = spec.weights
size = spec.resolution
-- The -gpu flag is 0-indexed, while cutorch.setDevice counts devices from 1.
local device_index = params.gpu + 1
cutorch.setDevice(device_index)

-- Build the network from the Caffe files with the requested backend, then
-- move it to the GPU.
local cnn = loadcaffe.load(proto_file, model_file, params.backend)
cnn = cnn:cuda()

-- Random input batch: batch_size x 3 x size x size, copied to the GPU.
local data = torch.randn(params.batch_size, 3, size, size)
data = data:cuda()
-- Time repeated forward/backward passes of `cnn` on `data`, printing the
-- per-iteration wall-clock times and their means, in milliseconds.
--
-- CUDA work is launched asynchronously, so cutorch.synchronize() is called
-- before each timer reset and before each timer read. Without it, the timer
-- would mostly measure launch overhead rather than execution.

-- Untimed warm-up pass. The first forward/backward calls may allocate output
-- and gradient buffers (and, depending on the backend, do other one-off
-- setup). Running them here keeps that cost out of the reported means. The
-- warm-up forward also reveals the output shape, which is needed to build a
-- random upstream gradient of matching size.
local warmup_out = cnn:forward(data)
local dout = torch.randn(warmup_out:size()):cuda()
cnn:backward(data, dout)
cutorch.synchronize()

local num_iterations = 50
local timer = torch.Timer()
local forward_times = torch.Tensor(num_iterations)
local backward_times = torch.Tensor(num_iterations)
local msg = 'Iteration %d / %d, forward %s ms, backward %s ms'

for i = 1, num_iterations do
  cutorch.synchronize()
  timer:reset()
  cnn:forward(data)
  cutorch.synchronize()
  local forward_time = timer:time().real * 1000  -- seconds -> milliseconds

  timer:reset()
  cnn:backward(data, dout)
  cutorch.synchronize()
  local backward_time = timer:time().real * 1000

  print(string.format(msg, i, num_iterations, forward_time, backward_time))
  forward_times[i] = forward_time
  backward_times[i] = backward_time
end

print(string.format('Mean forward time: %f', forward_times:mean()))
print(string.format('Mean backward time: %f', backward_times:mean()))
-- (Originally published as a GitHub gist; commenting on the gist requires
-- signing in to a GitHub account.)