-
-
Save amaarora/34d094f14c5af2645ac90449f5a43237 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_weights(*dims):
    """Create a trainable parameter of shape ``dims`` with scaled normal init.

    Values are drawn from a standard normal distribution and then divided by
    the first dimension, ``dims[0]``.

    Args:
        *dims: One or more integer sizes giving the shape of the tensor.

    Returns:
        An ``nn.Parameter`` of shape ``dims``.

    Raises:
        IndexError: If called with no dimensions, since ``dims[0]`` is needed
            for the scaling factor.
    """
    return nn.Parameter(torch.randn(dims) / dims[0])
def softmax(x):
    """Return the softmax of ``x`` taken along dimension 1.

    Args:
        x: A tensor with at least two dimensions; dimension 1 is normalized.

    Returns:
        A tensor of the same shape as ``x`` whose entries along dimension 1
        are non-negative and sum to 1.
    """
    # Softmax is invariant to adding a constant per row, so subtract the row
    # maximum first: the largest exponent becomes exp(0) == 1, which prevents
    # exp() from overflowing to inf (and the result becoming nan).
    shifted = x - x.max(dim=1, keepdim=True)[0]
    exps = torch.exp(shifted)  # computed once and reused for the denominator
    return exps / exps.sum(dim=1, keepdim=True)
class LogReg(nn.Module):
    """Logistic-regression classifier: one linear layer followed by log-softmax.

    Args:
        in_features: Number of input features per sample after flattening.
            Defaults to ``28 * 28``.
        n_classes: Number of output classes. Defaults to ``10``.
    """

    def __init__(self, in_features=28 * 28, n_classes=10):
        super().__init__()
        self.l1_w = get_weights(in_features, n_classes)  # Layer 1 weights
        self.l1_b = get_weights(n_classes)  # Layer 1 bias

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs.

        Args:
            x: A tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must multiply to the number of
                rows of ``l1_w``.

        Returns:
            A ``(batch, n_classes)`` tensor of log-probabilities.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1)
        x = (x @ self.l1_w) + self.l1_b  # Linear layer
        # Fused log-softmax: taking log(softmax(x)) in two steps overflows or
        # underflows for large-magnitude logits and produces nan / -inf.
        return nn.functional.log_softmax(x, dim=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment