Created
April 24, 2021 15:26
-
-
Save creotiv/31ef2bb0eb650d0e2a83671c5120568b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class Conv2d(nn.Module):
    """Naive 2D convolution (cross-correlation) written with explicit loops.

    The layer has no bias.  Its result is additionally divided by the number
    of input channels, i.e. the per-channel responses are averaged rather
    than summed.

    Args:
        inch: Number of input channels.
        outch: Number of output channels.
        kernel: Side length of the square kernel.
        padding: Zero padding added to each spatial border of the input.
        stride: Step between neighbouring windows.

    Raises:
        ValueError: If ``kernel`` or ``stride`` is smaller than 1, or
            ``padding`` is negative.
    """

    def __init__(self, inch, outch, kernel, padding=0, stride=1):
        super().__init__()
        if kernel < 1 or stride < 1 or padding < 0:
            raise ValueError(
                "kernel and stride must be >= 1 and padding must be >= 0"
            )
        # Kernels are drawn from a standard normal distribution.  Wrapping
        # them in nn.Parameter registers them with the module so that
        # parameters(), state_dict() and .to(device) all see them.
        self.kernels = nn.Parameter(torch.randn(outch, inch, kernel, kernel))
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.inch = inch
        self.outch = outch

    def get_size(self, s):
        """Return the output length for an (unpadded) input length ``s``.

        Uses ``floor((s + 2 * padding - kernel) / stride) + 1`` and never
        returns a negative value.
        """
        return max(0, (s + 2 * self.padding - self.kernel) // self.stride + 1)

    def forward(self, x):
        """Apply the convolution to ``x``.

        Args:
            x: Tensor indexed as ``(batch, channel, height, width)`` whose
                channel dimension equals ``inch``.

        Returns:
            Tensor of shape ``(batch, outch, out_h, out_w)`` where the
            spatial sizes are given by :meth:`get_size`.

        Raises:
            ValueError: If ``x`` is not 4-dimensional or its channel count
                differs from ``inch``.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4D input, got {x.dim()}D")
        batch, channels, height, width = x.shape
        if channels != self.inch:
            raise ValueError(
                f"expected {self.inch} input channels, got {channels}"
            )

        # The output size is derived from the *unpadded* dimensions because
        # get_size already accounts for the padding.
        out_h = self.get_size(height)
        out_w = self.get_size(width)

        if self.padding > 0:
            x = F.pad(x, (self.padding,) * 4, "constant", 0)

        size = self.kernel
        # Allocate the result with the input's dtype and device.
        out = x.new_zeros(batch, self.outch, out_h, out_w)
        for row in range(out_h):
            top = row * self.stride
            for col in range(out_w):
                left = col * self.stride
                # Full k x k window for every batch item and input channel.
                patch = x[:, :, top:top + size, left:left + size]
                # (b, 1, inch, k, k) * (outch, inch, k, k), reduced over the
                # input channels and both kernel axes -> (b, outch).
                out[:, :, row, col] = (
                    patch.unsqueeze(1) * self.kernels
                ).sum(dim=(2, 3, 4))

        # Mean over the input channels.
        return out / self.inch
# Smoke test: push a random batch of two 3-channel 10x10 inputs through a
# 3 -> 10 channel layer with a 3x3 kernel, no padding and a stride of 3,
# then print the shape of the result.
cour = Conv2d(inch=3, outch=10, kernel=3, padding=0, stride=3)
inp = torch.randn(2, 3, 10, 10)
y = cour(inp)
print(y.size())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment