This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn as nn
from torchviz import make_dot

# Use the GPU when one is present and fall back to the CPU otherwise; a
# hard-coded 'cuda' makes the tensor creation below raise on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# sending tensor to device at creation time
# `a` is created directly on `device` with requires_grad=True, so no later
# .to(device) call is needed.
a = torch.randn(1, requires_grad=True, dtype=torch.float, device=device)
# Visualize `a` with torchviz's make_dot.
plot1 = make_dot(a)
# sending tensor to device immediately after creating it
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build specialized power functions from the closure factory
# `exponentiation_builder` (defined separately; it must already be in scope
# when these lines run).
square = exponentiation_builder(2)
cube = exponentiation_builder(3)
fourth_power = exponentiation_builder(4)
# and so on and so forth...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def exponentiation_builder(exponent):
    """Return a function that raises its argument to ``exponent``.

    The returned ``skeleton_exponentiation`` is a closure: it keeps the
    ``exponent`` passed to this call, so each call to the builder yields an
    independent power function (e.g. ``exponentiation_builder(2)`` squares).
    """
    def skeleton_exponentiation(x):
        """Return ``x ** exponent`` using the captured ``exponent``."""
        return x ** exponent

    return skeleton_exponentiation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def skeleton_exponentiation(x):
    """Return ``x ** exponent``.

    ``exponent`` is a free variable here: it is neither a parameter nor a
    local, so it is looked up when the function is called. Defined at module
    level, it resolves to a global ``exponent`` and raises NameError if none
    exists; nested inside a factory such as ``exponentiation_builder``, it is
    captured from the enclosing call instead.
    """
    return x ** exponent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generic_exponentiation(x, exponent):
    """Return ``x`` raised to ``exponent`` (``x ** exponent``).

    Both the base and the exponent are supplied on every call.
    """
    return x ** exponent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def square(x):
    """Return ``x ** 2``."""
    return x ** 2


def cube(x):
    """Return ``x ** 3``."""
    return x ** 3


def fourth_power(x):
    """Return ``x ** 4``."""
    return x ** 4

# and so on and so forth...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
decoder_cnn = nn.Sequential( | |
# z_size -> (n_filters*2)*7*7 | |
nn.Linear(z_size, (n_filters*2)*int(img_size/4)**2), | |
# (n_filters*2)*7*7 -> (n_filters*2)@7x7 | |
nn.Unflatten(1, (n_filters*2, int(img_size/4), int(img_size/4))), | |
# (n_filters*2)@7x7 -> (n_filters*2)@7x7 | |
nn.ConvTranspose2d(n_filters*2, n_filters*2, kernel_size=3, stride=1, padding=1, output_padding=0), | |
nn.LeakyReLU(), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# set_seed is a project helper not shown here; presumably it seeds the RNGs
# for reproducibility — confirm against its definition.
set_seed(13)
z_size = 1        # width of the latent vector fed to the decoder's first Linear layer
n_filters = 32    # base number of conv filters (the decoder uses n_filters*2)
in_channels = 1
img_size = 28
# Channel-first shape of a single input image: (channels, height, width).
input_shape = (in_channels, img_size, img_size)
base_model = nn.Sequential( | |
# in_channels@28x28 -> n_filters@28x28 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Use the GPU when available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# AutoEncoder, encoder_var_cnn and decoder_cnn are defined elsewhere.
model_vae_cnn = AutoEncoder(encoder_var_cnn, decoder_cnn)
model_vae_cnn.to(device)

# reduction='none' returns the unreduced element-wise squared errors, leaving
# the aggregation to the training loop (TODO confirm how the loop reduces them).
loss_fn = nn.MSELoss(reduction='none')
optim = torch.optim.Adam(model_vae_cnn.parameters(), lr=0.0003)

num_epochs = 30
train_losses = []  # presumably appended to by the training loop — confirm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Use the GPU when available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model_vae is defined elsewhere.
model_vae.to(device)

# reduction='none' returns the unreduced element-wise squared errors, leaving
# the aggregation to the training loop (TODO confirm how the loop reduces them).
loss_fn = nn.MSELoss(reduction='none')
optim = torch.optim.Adam(model_vae.parameters(), lr=0.0003)

num_epochs = 30
train_losses = []  # presumably appended to by the training loop — confirm
# Weight for the reconstruction term; presumably balanced against another
# loss term in the training loop — confirm.
reconstruction_loss_factor = 1