Last active
August 29, 2015 14:15
-
-
Save mfigurnov/2863e30ce3f73a473d32 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [ ret ] = my_nnpool( input, pool, varargin )
%MY_NNPOOL  Loop-based reference implementation of 2-D spatial pooling.
%   Y = MY_NNPOOL(X, POOL) pools every slice X(:,:,d,n) over windows of
%   size POOL = [windowHeight windowWidth] (a scalar means a square
%   window). Y is a single array of size pooledHeight x pooledWidth x D x N:
%     pooledHeight = floor((H + padTop + padBottom - windowHeight)/strideY) + 1
%     pooledWidth  = floor((W + padLeft + padRight - windowWidth)/strideX) + 1
%
%   DZDX = MY_NNPOOL(X, POOL, DZDY) back-propagates DZDY (the derivative
%   with respect to Y) and returns a single array DZDX the size of X.
%   DZDY must have size pooledHeight x pooledWidth x D x N.
%
%   Options, given as name/value pairs and parsed by VL_ARGPARSE:
%     stride  Scalar or [strideY strideX], each >= 1. Default 1.
%     pad     Scalar or [top bottom left right], each >= 0 and smaller
%             than the window in that direction. Default 0.
%     method  'max' or 'avg', compared case-insensitively. Default 'max'.
%
%   Padding is implicit: each window is clipped to the input. Padded cells
%   therefore never win a 'max', and 'avg' divides by the number of
%   in-bounds cells of the window only.
%
%   A 'verbose' flag may precede the options (after DZDY in backward mode).
%   It is accepted and otherwise ignored.
%
%   For 'max' in backward mode, each output gradient is routed to the first
%   maximum found scanning the window row by row, left to right.

opts.stride = 1 ;
opts.pad = 0 ;
opts.method = 'max' ;

% A non-char first optional argument is DZDY and selects backward mode.
backMode = numel(varargin) > 0 && ~ischar(varargin{1}) ;
if backMode
  dzdy = varargin{1} ;
  optArgs = varargin(2:end) ;
else
  optArgs = varargin ;
end
% Testing the length first lets DZDY be the only optional argument.
if numel(optArgs) > 0 && ischar(optArgs{1}) && strcmpi(optArgs{1}, 'verbose')
  optArgs = optArgs(2:end) ;
end
opts = vl_argparse(opts, optArgs) ;

[windowHeight, windowWidth] = expand_pair(pool, 'SIZE') ;
height = size(input, 1) ;
width = size(input, 2) ;
D = size(input, 3) ;
N = size(input, 4) ;

[strideY, strideX] = expand_pair(opts.stride, 'STRIDE') ;
% This also rejects zero strides.
if strideX < 1 || strideY < 1
  error('At least one element of STRIDE is smaller than one.') ;
end

if numel(opts.pad) == 1
  pad = repmat(opts.pad, 1, 4) ;
elseif numel(opts.pad) == 4
  pad = opts.pad ;
else
  error('PAD has neither one nor four elements.') ;
end
padTop = pad(1) ;
padBottom = pad(2) ;
padLeft = pad(3) ;
padRight = pad(4) ;

if height < windowHeight || width < windowWidth
  error('Pooling SIZE is larger than the DATA.') ;
end
if windowHeight <= 0 || windowWidth <= 0
  error('A dimension of the pooling SIZE is void.') ;
end
if padLeft < 0 || padRight < 0 || padTop < 0 || padBottom < 0
  error('An element of PAD is negative.') ;
end
% Together with the checks above, keeping each pad below the window size
% means every clipped window overlaps the input (it is never empty).
if padLeft >= windowWidth || padRight >= windowWidth || ...
   padTop >= windowHeight || padBottom >= windowHeight
  error('A padding value is larger or equal than the size of the pooling window.') ;
end

isMax = strcmpi(opts.method, 'max') ;
if ~isMax && ~strcmpi(opts.method, 'avg')
  error('METHOD is not a supported method.') ;
end

pooledHeight = floor((height + (padTop + padBottom) - windowHeight)/strideY) + 1 ;
pooledWidth = floor((width + (padLeft + padRight) - windowWidth)/strideX) + 1 ;

% Geometry shared by every window computation.
g = struct('height', height, 'width', width, ...
           'windowHeight', windowHeight, 'windowWidth', windowWidth, ...
           'strideY', strideY, 'strideX', strideX, ...
           'padTop', padTop, 'padLeft', padLeft) ;

if ~backMode
  ret = zeros(pooledHeight, pooledWidth, D, N, 'single') ;
  for n = 1:N
    for d = 1:D
      for py = 1:pooledHeight
        for px = 1:pooledWidth
          [y1, y2, x1, x2] = pool_window(py, px, g) ;
          values = input(y1:y2, x1:x2, d, n) ;
          if isMax
            ret(py, px, d, n) = max(values(:)) ;
          else
            % Average over the in-bounds cells of the window only.
            ret(py, px, d, n) = sum(values(:)) / numel(values) ;
          end
        end
      end
    end
  end
else
  % DZDY must match the forward output exactly. Otherwise part of it would
  % be silently ignored, or the loops would fail with an indexing error.
  if size(dzdy, 1) ~= pooledHeight || size(dzdy, 2) ~= pooledWidth || ...
     size(dzdy, 3) ~= D || size(dzdy, 4) ~= N
    error('DZDY dimensions are incompatible with X and POOL.') ;
  end
  ret = zeros(height, width, D, N, 'single') ;
  for n = 1:N
    for d = 1:D
      for py = 1:pooledHeight
        for px = 1:pooledWidth
          [y1, y2, x1, x2] = pool_window(py, px, g) ;
          if isMax
            % Route the whole gradient to the window's first maximizer.
            [iy, ix] = first_max_rowmajor(input(y1:y2, x1:x2, d, n)) ;
            by = y1 + iy - 1 ;
            bx = x1 + ix - 1 ;
            ret(by, bx, d, n) = ret(by, bx, d, n) + dzdy(py, px, d, n) ;
          else
            % Spread the gradient uniformly over the in-bounds cells.
            area = (y2 - y1 + 1) * (x2 - x1 + 1) ;
            ret(y1:y2, x1:x2, d, n) = ret(y1:y2, x1:x2, d, n) + ...
              dzdy(py, px, d, n) / area ;
          end
        end
      end
    end
  end
end
end

function [a, b] = expand_pair(value, name)
%EXPAND_PAIR  Split a scalar or two-element VALUE into two scalars.
%   A scalar is used for both outputs. Any other element count raises an
%   error whose message starts with NAME.
if numel(value) == 1
  a = value ;
  b = value ;
elseif numel(value) == 2
  a = value(1) ;
  b = value(2) ;
else
  error('%s has neither one nor two elements.', name) ;
end
end

function [y1, y2, x1, x2] = pool_window(py, px, g)
%POOL_WINDOW  1-based row range [y1, y2] and column range [x1, x2] of the
%   window for output cell (py, px), clipped to the input bounds in G.
y1 = (py - 1) * g.strideY - g.padTop + 1 ;
x1 = (px - 1) * g.strideX - g.padLeft + 1 ;
y2 = min(y1 + g.windowHeight - 1, g.height) ;
x2 = min(x1 + g.windowWidth - 1, g.width) ;
y1 = max(y1, 1) ;
x1 = max(x1, 1) ;
end

function [iy, ix] = first_max_rowmajor(window)
%FIRST_MAX_ROWMAJOR  Position of the maximum of the 2-D array WINDOW.
%   The scan goes row by row, left to right, and uses a strict ">". On ties
%   the earliest element wins, and a NaN in WINDOW(1,1) is never displaced.
iy = 1 ;
ix = 1 ;
best = window(1, 1) ;
for y = 1:size(window, 1)
  for x = 1:size(window, 2)
    if window(y, x) > best
      best = window(y, x) ;
      iy = y ;
      ix = x ;
    end
  end
end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment