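## load the entire dataset in one batch (the loader's batch size must cover all images)
x, y = next(iter(train_dataset))
## print the tensor shapes and display one example
dim = x.shape[1]
print("Dimension of image:", x.shape, "\n",
      "Dimension of labels:", y.shape)
## un-flatten to (C, H, W), then move channels last (H, W, C) for matplotlib
plt.imshow(x[160].reshape(3, 224, 224).permute(1, 2, 0).numpy())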
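Displaying an example means undoing the flattening. Assuming `ReshapeTransform` flattened each (C, H, W) image tensor in the usual row-major order, every row of `x` is channel-major, while matplotlib expects (H, W, C). Reversing all three axes instead (which is what `.T` does on a 3-D tensor) would give (W, H, C) and show the picture transposed. The stdlib-only sketch below mimics `reshape(3, H, W)` followed by `permute(1, 2, 0)` on a tiny 3x2x2 "image"; the helper names and sizes are illustrative assumptions.

```python
def unflatten_chw(flat, c, h, w):
    """Rebuild a C x H x W nested list from a flat, channel-major list (like reshape(c, h, w))."""
    assert len(flat) == c * h * w
    return [[[flat[ch * h * w + row * w + col] for col in range(w)]
             for row in range(h)]
            for ch in range(c)]


def chw_to_hwc(img):
    """Move the channel axis last (like permute(1, 2, 0)) so each pixel holds its C values."""
    c, h, w = len(img), len(img[0]), len(img[0][0])
    return [[[img[ch][row][col] for ch in range(c)] for col in range(w)]
            for row in range(h)]


# a tiny 3-channel, 2x2 "image" stored as one flat vector
flat = list(range(12))
chw = unflatten_chw(flat, 3, 2, 2)
hwc = chw_to_hwc(chw)
print(hwc[0][1])  # pixel at row 0, column 1 -> [1, 5, 9]
```

Each pixel gathers one value from each channel plane, which is exactly the layout `plt.imshow` needs for an RGB image.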
## configure root folder on your gdrive
data_dir = 'gdrive/My Drive/DAIR RESOURCES/TF to PT/datasets/hymenoptera_data'

## custom transform to flatten the image tensors
class ReshapeTransform:
    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, img):
        result = torch.reshape(img, self.new_size)
        return result
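`ReshapeTransform` follows the same convention as torchvision transforms: any object with a `__call__(img)` method can be chained, and `transforms.Compose` applies each step in order. In this notebook it would presumably be used as something like `ReshapeTransform((-1,))` after `transforms.ToTensor()`, but that pipeline is not shown, so treat it as an assumption. The stdlib sketch below reproduces the chaining pattern without torch; `ComposeSketch`, `ScaleToUnit`, and `FlattenSketch` are hypothetical stand-ins working on nested lists, not torchvision classes.

```python
class ComposeSketch:
    """Apply a sequence of callables in order, mirroring how transforms.Compose chains transforms."""

    def __init__(self, steps):
        self.steps = steps

    def __call__(self, img):
        for step in self.steps:
            img = step(img)
        return img


class ScaleToUnit:
    """Hypothetical stand-in for scaling 0-255 pixel values into [0, 1]."""

    def __call__(self, img):
        return [[[v / 255 for v in row] for row in ch] for ch in img]


class FlattenSketch:
    """Hypothetical analogue of ReshapeTransform((-1,)): collapse C x H x W into one flat list."""

    def __call__(self, img):
        return [v for ch in img for row in ch for v in row]


pipeline = ComposeSketch([ScaleToUnit(), FlattenSketch()])
img = [[[0, 255], [51, 102]], [[255, 0], [0, 255]], [[0, 0], [255, 255]]]  # 3 x 2 x 2
out = pipeline(img)
print(len(out), out[:4])  # 12 [0.0, 1.0, 0.2, 0.4]
```

Because the flatten step runs last, every image reaches the model as a single vector, which is why `x.shape[1]` gives the input dimension.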
## mount Google Drive so the dataset folder is reachable under ./gdrive
from google.colab import drive
drive.mount('gdrive', force_remount=True)
## configuration to detect cuda or cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
## Import the usual libraries
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, models, transforms
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
## test the model
sample = torch.tensor([10.0], dtype=torch.float)
predicted = model(sample)
print(predicted.detach().item())
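Because `layer1` is created with `bias=False`, the model can only represent y = w * x, while the six training points follow y = 2x - 1. The best such w under mean squared error has a closed form, w* = sum(x*y) / sum(x*x), so the printed prediction for 10.0 can be anticipated, assuming training converges. A stdlib sketch with exact fractions:

```python
from fractions import Fraction

# the same six training points as the x / y tensors
xs = [-1, 0, 1, 2, 3, 4]
ys = [-3, -1, 1, 3, 5, 7]

# least-squares slope for a line through the origin (no bias): w* = sum(x*y) / sum(x*x)
w_star = Fraction(sum(x * y for x, y in zip(xs, ys)), sum(x * x for x in xs))
prediction = w_star * 10

print(w_star, round(float(prediction), 2))  # 53/31 17.1
```

So a converged bias-free model should print roughly 17.1 rather than the 19 that y = 2x - 1 would give; the gap comes from the missing intercept, not from too little training.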
## training
for i in range(150):
    model.train()
    ## forward
    output = model(x)
    loss = criterion(output, y)
    optimizer.zero_grad()
    ## backward + update model params
    loss.backward()
    optimizer.step()
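For this one-weight, bias-free model, the backward pass and the SGD step reduce to a single scalar update: the gradient of the mean squared error is dL/dw = (2/n) * sum(x * (w*x - y)), and plain SGD subtracts lr times that gradient. The stdlib sketch below replays the 150 full-batch iterations with lr = 0.01; starting from w = 0.0 is an assumption, since `nn.Linear` initializes its weight randomly.

```python
xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0]
n = len(xs)

lr = 0.01   # same learning rate as the SGD optimizer
w = 0.0     # assumed starting weight; nn.Linear draws its initial weight at random

for _ in range(150):
    # forward: predictions of the bias-free model y_hat = w * x
    preds = [w * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n  # MSE with mean reduction
    # backward: d(loss)/dw = (2/n) * sum(x * (y_hat - y))
    grad = 2.0 / n * sum(x * (p - y) for x, p, y in zip(xs, preds, ys))
    # update: plain SGD step without momentum
    w -= lr * grad

print(round(w, 4), round(loss, 4))  # 1.7097 0.5645
```

Each step shrinks the distance to the optimum by a factor of 1 - lr * 2 * sum(x*x) / n, about 0.897 here, so 150 iterations are plenty for this particular model.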
## loss function
criterion = nn.MSELoss()
## optimizer algorithm
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
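`nn.MSELoss()` with its default `reduction='mean'` averages the squared errors over all elements (`reduction='sum'` would add them instead). A stdlib sketch of the mean version, evaluated at an assumed untrained weight of w = 0 so that every prediction is 0, gives a concrete starting loss for the toy data:

```python
def mse(preds, targets):
    """Mean of squared differences, matching MSELoss's default 'mean' reduction."""
    assert len(preds) == len(targets)
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


ys = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0]
preds_at_zero = [0.0] * len(ys)  # a bias-free model with w = 0 predicts 0 everywhere
print(mse(preds_at_zero, ys))    # 94 / 6, about 15.67
```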
## a single linear neuron (no hidden layer); without a bias term it computes y = w * x
layer1 = nn.Linear(1, 1, bias=False)
model = nn.Sequential(layer1)
## our data in tensor form
x = torch.tensor([[-1.0], [0.0], [1.0], [2.0], [3.0], [4.0]], dtype=torch.float)
y = torch.tensor([[-3.0], [-1.0], [1.0], [3.0], [5.0], [7.0]], dtype=torch.float)
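The six (x, y) pairs lie exactly on the line y = 2x - 1. An ordinary least-squares fit with an intercept, slope = sum((x - x_mean)(y - y_mean)) / sum((x - x_mean)^2) and intercept = y_mean - slope * x_mean, confirms this; the intercept of -1 is the value a bias term would have to learn. A stdlib sketch with exact fractions:

```python
from fractions import Fraction

xs = [-1, 0, 1, 2, 3, 4]
ys = [-3, -1, 1, 3, 5, 7]

x_mean = Fraction(sum(xs), len(xs))
y_mean = Fraction(sum(ys), len(ys))

# ordinary least squares with an intercept term
slope = (sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
         / sum((x - x_mean) ** 2 for x in xs))
intercept = y_mean - slope * x_mean

print(slope, intercept)  # 2 -1
```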