-- Gist soumith/1f7645f14738d39be2b5, created October 11, 2014 20:52.
-- Save it to your computer and use it in GitHub Desktop.
-- Title: cuDNN SpatialMaxPooling bug
-- Note from GitHub: this file contains hidden or bidirectional Unicode text that
-- may be interpreted or compiled differently than what appears below. To review,
-- open the file in an editor that reveals hidden Unicode characters.
-- (Learn more about bidirectional Unicode characters.)
require 'cudnn'
require 'cunn'

-- Collection of test functions; registered with the tester via
-- `mytester:add(cudnntest)` at the end of the file.
local cudnntest = {}

-- Maximum allowed absolute element-wise difference between the cudnn
-- results and the nn reference results.
local precision_forward = 1e-4   -- for the forward output
local precision_backward = 1e-2  -- for the gradient returned by backward
--- Check cudnn.SpatialMaxPooling against nn.SpatialMaxPooling on random CUDA
-- input. Both the forward output and the gradient returned by backward are
-- compared.
--
-- The input spatial size is derived from the requested output size, kernel
-- and stride, in = (out - 1) * stride + kernel, so that pooling yields exactly
-- outj x outi outputs.
function cudnntest.SpatialMaxPooling()
   -- Largest absolute element-wise difference between two tensors, computed
   -- after converting both to FloatTensors.
   local function max_abs_diff(a, b)
      return (a:float() - b:float()):abs():max()
   end

   local bs = 10          -- first tensor dimension
   local from = 34        -- second tensor dimension
   -- The ki/si pair applies to the last (ini/outi) dimension.
   -- The kj/sj pair applies to the third (inj/outj) dimension.
   local ki, kj = 4, 3    -- kernel sizes
   local si, sj = 4, 3    -- strides
   local outi, outj = 62, 90
   local ini = (outi - 1) * si + ki
   local inj = (outj - 1) * sj + kj

   local input = torch.randn(bs, from, inj, ini):cuda()
   local gradOutput = torch.randn(bs, from, outj, outi):cuda()

   -- Reference results from nn.
   local sconv = nn.SpatialMaxPooling(ki, kj, si, sj):cuda()
   local groundtruth = sconv:forward(input)
   local groundgrad = sconv:backward(input, gradOutput)
   cutorch.synchronize()

   -- Results under test from cudnn.
   local gconv = cudnn.SpatialMaxPooling(ki, kj, si, sj):cuda()
   -- NOTE(review): the module is run forward twice. A comment here previously
   -- read "serialize and deserialize", but no serialization is performed.
   -- The first, discarded call is kept in case this bug repro depends on a
   -- repeated forward. Confirm whether it is still needed.
   gconv:forward(input)
   local rescuda = gconv:forward(input)
   local resgrad = gconv:backward(input, gradOutput)
   cutorch.synchronize()

   mytester:assertlt(max_abs_diff(rescuda, groundtruth), precision_forward,
                     'error on state (forward) ')
   mytester:assertlt(max_abs_diff(resgrad, groundgrad), precision_backward,
                     'error on state (backward) ')
end
-- Script entry point: configure Torch, register the suite and run it.
torch.setdefaulttensortype('torch.FloatTensor')
math.randomseed(os.time())

-- Intentionally global: cudnntest.SpatialMaxPooling (defined above) refers to
-- `mytester` by name, so a `local` declared here would not be visible to it.
mytester = torch.Tester()
mytester:add(cudnntest)
torch.manualSeed(10)

-- The original passed the undefined global `tests`, which evaluates to nil.
-- Calling with no argument is equivalent. The stray `print(i)` debug line,
-- which only ever printed nil, has been removed.
mytester:run()
-- Sign up for free to join this conversation on GitHub.
-- Already have an account? Sign in to comment.