-
-
Save jagandecapri/d98b98565f28855cfac8ecb7e90e3fca to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Trains a one-feature linear regression with mini-batch SGD and reports
# training/validation losses. Relies on `x`, `y`, `device` and
# `make_train_step` being defined earlier in the file.

# Fixes the RNG so the random train/validation split is reproducible
torch.manual_seed(42)

x_tensor = torch.from_numpy(x).float()
y_tensor = torch.from_numpy(y).float()

# Builds dataset with ALL data
dataset = TensorDataset(x_tensor, y_tensor)

# Splits randomly into train and validation datasets.
# NOTE: integer lengths passed to random_split must sum to len(dataset),
# so this split assumes the dataset holds exactly 100 samples.
train_dataset, val_dataset = random_split(dataset, [80, 20])

# Builds a loader for each dataset to perform mini-batch gradient descent
train_loader = DataLoader(dataset=train_dataset, batch_size=16)
val_loader = DataLoader(dataset=val_dataset, batch_size=20)

# Builds a simple sequential model
model = nn.Sequential(nn.Linear(1, 1)).to(device)
print(model.state_dict())

# Sets hyper-parameters
lr = 1e-1
n_epochs = 150

# Defines loss function and optimizer
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.SGD(model.parameters(), lr=lr)

losses = []
val_losses = []

# Creates function to perform train step from model, loss and optimizer
train_step = make_train_step(model, loss_fn, optimizer)

# Training loop
for epoch in range(n_epochs):
    # The evaluation phase below leaves the model in eval mode, so switch
    # back explicitly before training. `make_train_step` is defined outside
    # this section and may already do this; calling it here is harmless.
    model.train()

    # Uses loader to fetch one mini-batch at a time for training
    for x_batch, y_batch in train_loader:
        # Sends the mini-batch data to the device
        # so it matches the location of the MODEL
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # One step of training
        loss = train_step(x_batch, y_batch)
        losses.append(loss)

    # After finishing training steps for all mini-batches,
    # it is time for evaluation.
    # Evaluation mode only needs to be set once per epoch, not per batch.
    model.eval()

    # No gradients are needed for validation, so autograd is disabled
    with torch.no_grad():
        # Uses loader to fetch one mini-batch at a time for validation
        for x_val, y_val in val_loader:
            # Again, sends data to same device as model
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            # Makes predictions
            yhat = model(x_val)

            # Computes validation loss; nn.MSELoss expects (input, target)
            val_loss = loss_fn(yhat, y_val)
            val_losses.append(val_loss.item())

print(model.state_dict())
print(np.mean(losses))
print(np.mean(val_losses))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment