Skip to content

Instantly share code, notes, and snippets.

@creotiv
Created April 24, 2021 15:26
Show Gist options
  • Save creotiv/31ef2bb0eb650d0e2a83671c5120568b to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv2d(nn.Module):
    """Naive 2-D convolution written with explicit loops, for teaching purposes.

    Unlike ``torch.nn.Conv2d``, there is no bias, and the per-window sum over
    input channels is divided by ``inch`` (a mean over input channels rather
    than a sum).

    Args:
        inch: Number of input channels.
        outch: Number of output channels.
        kernel: Side length of the square kernel.
        padding: Zero padding added on each side of both spatial axes.
        stride: Step between consecutive windows (same for both axes).
    """

    def __init__(self, inch: int, outch: int, kernel: int, padding: int = 0, stride: int = 1) -> None:
        super().__init__()
        # Kernels initialised from a standard normal. Registered as a Parameter
        # so they show up in .parameters()/state_dict() and follow .to(device);
        # a plain tensor attribute is invisible to optimizers and device moves.
        self.kernels = nn.Parameter(torch.randn(outch, inch, kernel, kernel))
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.inch = inch
        self.outch = outch

    def get_size(self, s: int) -> int:
        """Return the output length along one axis whose *unpadded* length is *s*.

        Uses ``floor((s - kernel + 2 * padding) / stride) + 1``; the padding
        term is included here, so *s* must be the size before padding.
        """
        return (s - self.kernel + 2 * self.padding) // self.stride + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve *x* of shape ``(batch, inch, height, width)``.

        Returns:
            Tensor of shape ``(batch, outch, get_size(height), get_size(width))``
            on the same device as *x*.

        Raises:
            ValueError: If *x* is not 4-D, its channel count differs from
                ``inch``, or the padded input is smaller than the kernel.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (batch, channels, height, width) input, got shape {tuple(x.shape)}"
            )
        b, c, h, w = x.shape
        if c != self.inch:
            raise ValueError(f"expected {self.inch} input channels, got {c}")

        # Output size comes from the UNPADDED dims: get_size already adds the
        # padding, so computing it after F.pad would count the padding twice.
        out_h = self.get_size(h)
        out_w = self.get_size(w)
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"kernel size {self.kernel} is larger than the padded input "
                f"({h + 2 * self.padding}x{w + 2 * self.padding})"
            )

        if self.padding > 0:
            x = F.pad(x, (self.padding,) * 4, "constant", 0)

        k, s = self.kernel, self.stride
        out = x.new_zeros(b, self.outch, out_h, out_w)
        # Iterate over output positions directly, so every visited window is
        # guaranteed to lie fully inside the padded input (no skip checks).
        for i in range(out_h):
            top = i * s
            for j in range(out_w):
                left = j * s
                # window: (b, inch, k, k) -> (b, 1, inch, k, k), broadcast against
                # kernels (outch, inch, k, k), then sum channel and spatial dims.
                window = x[:, :, top:top + k, left:left + k]
                out[:, :, i, j] = (window.unsqueeze(1) * self.kernels).sum(dim=(2, 3, 4))

        # Mean over input channels (deliberate choice of this implementation).
        return out / self.inch
# Smoke test: 3 -> 10 channels, 3x3 kernel, no padding, stride 3, applied to a
# random batch of two 3-channel 10x10 inputs; prints the output shape.
cour = Conv2d(inch=3, outch=10, kernel=3, padding=0, stride=3)
inp = torch.randn(2, 3, 10, 10)
y = cour(inp)
print(y.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment