Skip to content

Instantly share code, notes, and snippets.

@SamratSahoo
Last active October 31, 2020 23:57
Show Gist options
  • Save SamratSahoo/ebd53a32c43c9cac808f91e941ae7129 to your computer and use it in GitHub Desktop.
Save SamratSahoo/ebd53a32c43c9cac808f91e941ae7129 to your computer and use it in GitHub Desktop.
Constructor (`__init__`) for the `Autoencoder` class
class Autoencoder(nn.Module):
    """Fully connected autoencoder that squeezes 784 input features into a 3-dim code.

    The encoder maps 784 -> 128 -> 64 -> 12 -> 3 and the decoder mirrors it back
    to 784 outputs bounded to [-1, 1] by a final ``Tanh``. That range matches the
    ``Normalize([0.5], [0.5])`` transform applied to the MNIST images loaded here.
    NOTE(review): inputs presumably arrive as 28x28 MNIST images flattened to 784
    values by the caller; no reshape happens in this constructor.

    The constructor also builds the MNIST dataset and its shuffled ``DataLoader``,
    an Adam optimizer over all encoder/decoder parameters, and an MSE criterion,
    storing each as an attribute.
    """

    def __init__(self, epochs=100, batchSize=128, learningRate=1e-3, *,
                 dataRoot='./Data', download=False, weightDecay=1e-5):
        """Build the network, the MNIST data pipeline, the optimizer and the loss.

        Args:
            epochs: Number of training epochs. Only stored here; presumably
                read by training code defined elsewhere.
            batchSize: Mini-batch size used by ``self.dataLoader``.
            learningRate: Learning rate passed to the Adam optimizer.
            dataRoot: Directory holding (or receiving) the MNIST files.
                Defaults to ``'./Data'``, the previously hard-coded path.
            download: If True, torchvision downloads MNIST into ``dataRoot``
                when the files are missing. Defaults to False, which is also
                torchvision's default, so existing callers behave as before.
            weightDecay: L2 penalty passed to Adam as ``weight_decay``.
                Defaults to ``1e-5``, the previously hard-coded value.

        Raises:
            RuntimeError: Raised by torchvision's ``MNIST`` when the dataset is
                not found under ``dataRoot`` and ``download`` is False.
        """
        super().__init__()

        # Encoder network: 784 -> 128 -> 64 -> 12 -> 3 with ReLU between layers.
        # The last Linear has no activation, so the 3-dim code is unbounded.
        # ReLU(True) operates in place to avoid allocating a new tensor.
        self.encoder = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3),
        )

        # Decoder network mirrors the encoder: 3 -> 12 -> 64 -> 128 -> 784.
        # Tanh bounds the reconstruction to [-1, 1], the range of the
        # normalized inputs produced by the transforms below.
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 784),
            nn.Tanh(),
        )

        # Training hyperparameters.
        self.epochs = epochs
        self.batchSize = batchSize
        self.learningRate = learningRate
        self.weightDecay = weightDecay

        # Data + data loader. ToTensor scales pixel values to [0, 1]; Normalize
        # with mean 0.5 and std 0.5 computes (x - 0.5) / 0.5, giving [-1, 1].
        self.imageTransforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        # No `train=` argument, so torchvision's default training split is used.
        self.data = MNIST(dataRoot, transform=self.imageTransforms,
                          download=download)
        self.dataLoader = torch.utils.data.DataLoader(dataset=self.data,
                                                      batch_size=self.batchSize,
                                                      shuffle=True)

        # Must be created after encoder/decoder so self.parameters() covers both.
        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=self.learningRate,
                                          weight_decay=self.weightDecay)
        self.criterion = nn.MSELoss()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment