Skip to content

Instantly share code, notes, and snippets.

@creotiv
Created January 31, 2020 15:13
Show Gist options
  • Save creotiv/a0a9beb17692517b2cd44e5c9296a33d to your computer and use it in GitHub Desktop.
Save creotiv/a0a9beb17692517b2cd44e5c9296a33d to your computer and use it in GitHub Desktop.
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
class Model(Layer):
    """Fully connected classifier built as a stack of Linear/Relu layers.

    With the default ``sizes`` the stack is exactly
    Linear(784, 100) -> Relu -> Linear(100, 200) -> Relu -> Linear(200, 10).
    A different ``sizes`` sequence changes the depth/width without editing
    the class.

    Args:
        lr: learning rate passed to every ``Linear`` layer.
        sizes: layer widths, input width first and output width last.
            Each consecutive pair becomes one ``Linear`` layer, with a
            ``Relu`` inserted between consecutive ``Linear`` layers (no
            activation after the last one).

    Raises:
        ValueError: if ``sizes`` has fewer than two entries.
    """

    def __init__(self, lr=0.00001, sizes=(784, 100, 200, 10)):
        # NOTE(review): Layer.__init__ is not called (same as before);
        # confirm the base class needs no initialisation.
        sizes = tuple(sizes)
        if len(sizes) < 2:
            raise ValueError("sizes must contain at least an input and an output width")
        self.lr = lr
        self.layers = []
        # Build layers in the same order as the original literal list
        # (Linear, Relu, Linear, ...), so any init-time RNG use is unchanged.
        for n_in, n_out in zip(sizes, sizes[1:]):
            if self.layers:
                self.layers.append(Relu())
            self.layers.append(Linear(n_in, n_out, lr=self.lr))

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the last output."""
        for layer in self.layers:
            x = layer(x)
        return x

    def backward(self, grad):
        """Propagate ``grad`` through the layers in reverse order.

        Each layer's ``backward`` receives the gradient coming from the layer
        after it and returns the gradient to hand to the layer before it.
        Presumably each ``Linear.backward`` also applies its own update with
        ``lr`` — TODO confirm against the Linear implementation.

        Returns:
            The gradient returned by the first layer's ``backward``.
        """
        for layer in reversed(self.layers):
            grad = layer.backward(grad)
        return grad
# Train Model on MNIST, then print predictions for a few images.
# NOTE(review): Layer, Linear, Relu, SoftmaxCrossentropyWithLogits, np and plt
# are not imported in the visible part of this file; they are assumed to be
# provided elsewhere — confirm.

BATCH_SIZE = 2   # samples per training step
LOG_EVERY = 100  # print the mean loss once per this many steps
EPOCHS = 5

simple = transforms.Compose([
    transforms.ToTensor(),  # converts to [0,1] interval
])
ds = MNIST('./mnist', download=True, transform=simple)
# shuffle=True reshuffles the training set every epoch (standard for SGD).
# drop_last=True keeps every batch at exactly BATCH_SIZE samples.
# pin_memory is not used: batches are converted to NumPy on the CPU, and
# page-locked memory only helps host-to-GPU copies.
ld = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
mm = Model()
loss = SoftmaxCrossentropyWithLogits()

_loss_avg = 0
for e in range(EPOCHS):
    _loss_avg = 0  # each epoch starts a fresh reporting window
    for i, (img, label) in enumerate(ld):
        # Flatten each image to a row vector, keeping the real batch
        # dimension instead of a hard-coded size.
        x = img.view(img.shape[0], -1).numpy()
        res = mm(x)
        _loss = loss(res, label.numpy())
        _loss_avg += _loss.mean()  # accumulate per-batch mean loss
        grad = loss.backward(1)
        mm.backward(grad)
        # (i + 1) makes every report average exactly LOG_EVERY steps; testing
        # i itself would also fire after the very first step.
        if (i + 1) % LOG_EVERY == 0:
            print(_loss_avg / LOG_EVERY)
            _loss_avg = 0
    print('---------')

# Show the first 10 images of ds with the model's predicted class.
# NOTE: ds is the training split (torchvision MNIST defaults to train=True),
# so these are not held-out predictions.
for i in range(10):
    img, target = ds[i]
    plt.imshow(img[0])
    plt.show()
    x = img.view(1, -1).numpy()
    res = mm(x)[0]
    pred = np.argmax(res)
    print(f'target: {target} predicted: {pred}' )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment