Skip to content

Instantly share code, notes, and snippets.

@dvgodoy
Created April 30, 2022 09:07
Show Gist options
  • Save dvgodoy/e28238f718c38fb5270c8c4f027c45c8 to your computer and use it in GitHub Desktop.
Save dvgodoy/e28238f718c38fb5270c8c4f027c45c8 to your computer and use it in GitHub Desktop.
class EncoderVar(nn.Module):
    """Variational encoder head on top of an arbitrary feature extractor.

    ``base_model`` maps a batch of inputs to features; two linear layers then
    map those features to the mean (``lin_mu``) and log-variance (``lin_var``)
    of a ``z_size``-dimensional Gaussian, from which ``forward`` draws a
    sample using the reparameterization trick.

    Args:
        input_shape: Shape of a single input example, without the batch
            dimension (used only to build the dummy probe input).
        z_size: Dimensionality of the latent vector z.
        base_model: Feature extractor. Only ``size(1)`` of its output is
            read, so its output is presumably ``(batch, features)`` —
            confirm for non-standard base models.
    """

    def __init__(self, input_shape, z_size, base_model):
        super().__init__()
        self.z_size = z_size
        self.input_shape = input_shape
        self.base_model = base_model
        output_size = self.get_output_size()
        self.lin_mu = nn.Linear(output_size, z_size)
        self.lin_var = nn.Linear(output_size, z_size)

    def get_output_size(self):
        """Return ``size(1)`` of ``base_model``'s output for one dummy input.

        A single all-zeros example of shape ``(1, *input_shape)`` is pushed
        through the base model. The probe runs without autograd and with the
        base model temporarily in eval mode, so layers like BatchNorm neither
        reject the batch of one nor update their running statistics. The
        base model's original train/eval mode is restored afterwards, even
        if the probe raises.
        """
        # Use the parameters' full device (keeping the index, e.g. cuda:1)
        # and floating dtype; parameter-free models (e.g. a bare nn.Flatten)
        # fall back to torch's defaults instead of raising StopIteration.
        param = next(self.base_model.parameters(), None)
        device = None if param is None else param.device
        dtype = None
        if param is not None and param.is_floating_point():
            dtype = param.dtype
        dummy = torch.zeros(1, *self.input_shape, device=device, dtype=dtype)

        was_training = self.base_model.training
        self.base_model.eval()
        try:
            with torch.no_grad():
                size = self.base_model(dummy).size(1)
        finally:
            self.base_model.train(was_training)
        return size

    def kl_loss(self):
        """Return the element-wise KL divergence KL(N(mu, var) || N(0, 1)).

        Uses the ``mu`` and ``log_var`` stored by the most recent ``forward``
        call, so ``forward`` must have been called first. The result has the
        same shape as ``mu`` and is not reduced; summing or averaging is left
        to the caller.
        """
        kl_loss = -0.5 * (1 + self.log_var - self.mu ** 2 - torch.exp(self.log_var))
        return kl_loss

    def forward(self, x):
        """Encode ``x`` and return a latent sample ``z = mu + eps * std``.

        Side effect: stores ``self.mu`` and ``self.log_var`` so that
        ``kl_loss`` can be computed for this batch.
        """
        # The base model is the same feature extractor a plain AE would use.
        base_out = self.base_model(x)
        # Two heads: means and log-variances of the latent Gaussian.
        self.mu = self.lin_mu(base_out)
        self.log_var = self.lin_var(base_out)
        # std = exp(log_var / 2) = sqrt(var)
        std = torch.exp(self.log_var / 2)
        # Sampling eps (rather than z itself) keeps z differentiable with
        # respect to mu and std (reparameterization trick).
        eps = torch.randn_like(self.mu)
        z = self.mu + eps * std
        return z
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment