Skip to content

Instantly share code, notes, and snippets.

@jagandecapri
Forked from dvgodoy/torch101_full.py
Created February 21, 2022 12:06
Show Gist options
  • Save jagandecapri/d98b98565f28855cfac8ecb7e90e3fca to your computer and use it in GitHub Desktop.
Save jagandecapri/d98b98565f28855cfac8ecb7e90e3fca to your computer and use it in GitHub Desktop.
# Seed torch's global RNG so the random train/validation split, the loader
# shuffling and the model's initial weights are reproducible run to run.
torch.manual_seed(42)

# Convert the NumPy arrays to float32 tensors.
# NOTE(review): `x` and `y` are defined outside this view; the model below
# takes one input feature, so they presumably have shape (N, 1) — confirm.
x_tensor = torch.from_numpy(x).float()
y_tensor = torch.from_numpy(y).float()

# Builds dataset with ALL data
dataset = TensorDataset(x_tensor, y_tensor)

# Splits randomly into train (80%) and validation (20%) datasets.
# The lengths are derived from the dataset size instead of being hard-coded
# as [80, 20], so the split stays valid for any number of samples (for 100
# samples it is exactly [80, 20]).
n_train = int(0.8 * len(dataset))
n_val = len(dataset) - n_train
train_dataset, val_dataset = random_split(dataset, [n_train, n_val])

# Builds a loader for each dataset to perform mini-batch gradient descent.
# The training loader reshuffles every epoch so the mini-batches differ from
# epoch to epoch; validation order is irrelevant, so it is not shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=20)

# Builds a simple sequential model: one linear layer, 1 input -> 1 output.
model = nn.Sequential(nn.Linear(1, 1)).to(device)
print(model.state_dict())

# Sets hyper-parameters
lr = 1e-1
n_epochs = 150

# Defines loss function and optimizer
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.SGD(model.parameters(), lr=lr)

# Per-mini-batch losses accumulated over ALL epochs.
losses = []
val_losses = []

# Creates function to perform train step from model, loss and optimizer.
# NOTE(review): make_train_step is defined outside this view; presumably it
# runs forward/backward/optimizer.step and returns the loss as a Python
# float — confirm.
train_step = make_train_step(model, loss_fn, optimizer)

# Training loop
for epoch in range(n_epochs):
    # Put the model (back) into training mode: it is switched to eval mode
    # for validation at the end of every epoch. For nn.Linear this has no
    # numerical effect, but it matters for layers such as Dropout/BatchNorm.
    model.train()

    # Uses loader to fetch one mini-batch at a time for training
    for x_batch, y_batch in train_loader:
        # Sends the mini-batch data to the device so it matches the model
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        # One step of training
        loss = train_step(x_batch, y_batch)
        losses.append(loss)

    # After finishing the training steps for all mini-batches, evaluate.
    # Switch to evaluation mode once, before the validation loop.
    model.eval()
    # Disable autograd: validation must not build a computation graph
    # (saves memory/time) and must not contribute to any gradients.
    with torch.no_grad():
        # Uses loader to fetch one mini-batch at a time for validation
        for x_val, y_val in val_loader:
            # Sends data to the same device as the model
            x_val = x_val.to(device)
            y_val = y_val.to(device)
            # Makes predictions
            yhat = model(x_val)
            # Computes validation loss; nn.MSELoss takes (input, target),
            # i.e. prediction first, then ground truth.
            val_loss = loss_fn(yhat, y_val)
            val_losses.append(val_loss.item())

print(model.state_dict())
# Averages over every mini-batch of every epoch.
print(np.mean(losses))
print(np.mean(val_losses))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment