This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def cross_entropy(self, scaled_logits, one_hot, eps=1e-12):
    """Return the per-example cross-entropy of predicted class probabilities.

    ``scaled_logits`` must already contain probabilities (not raw logits),
    because only ``-log`` is applied here; presumably it is the softmax
    output of ``logistic_regression`` -- confirm with the caller.

    Args:
        scaled_logits: tensor of predicted probabilities (assumed shape
            (batch_size, num_classes) -- TODO confirm).
        one_hot: one-hot targets broadcastable to ``scaled_logits``. Any
            dtype (int, float or bool) is accepted; it is cast to a bool mask.
        eps: lower bound applied to the selected probabilities before the
            log, so a probability that underflowed to 0 yields a large finite
            loss instead of ``inf``.

    Returns:
        1-D tensor containing ``-log(p)`` for every probability selected by
        ``one_hot`` (one entry per example for proper one-hot rows).
    """
    if self.library == "tf":
        # tf.boolean_mask expects a boolean mask; one-hot labels are usually int/float.
        mask = tf.cast(one_hot, tf.bool)
        true_class_probs = tf.boolean_mask(scaled_logits, mask)
        return -tf.math.log(tf.maximum(true_class_probs, eps))
    # Any other library value is treated as PyTorch, as in the original.
    # torch.masked_select raises unless the mask is a BoolTensor.
    mask = one_hot.bool()
    true_class_probs = torch.masked_select(scaled_logits, mask)
    return -torch.log(torch.clamp(true_class_probs, min=eps))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def logistic_regression(self, X, W, b):
    """Apply an affine map followed by ``self.softmax`` to a batch of inputs.

    Args:
        X: input batch. It is reshaped to ``(-1, W.shape[0])``, i.e. rows of
            ``W.shape[0]`` features; for image batches whose per-example size
            equals ``W.shape[0]`` this flattens each image (e.g. presumably
            (64, 32, 32, 3) -> (64, 32*32*3) in this project -- confirm).
        W: weight matrix, assumed shape (num_features, num_outputs).
        b: bias added to every row of ``flatten_X @ W`` (broadcast).

    Returns:
        ``self.softmax(flatten_X @ W + b)``.
    """
    num_features = W.shape[0]
    if self.library == "tf":
        # flatten_X has shape (batch_size, W*H*Channels) --> e.g. (64, 32*32*3)
        flatten_X = tf.reshape(X, (-1, num_features))
        logits = tf.matmul(flatten_X, W) + b
    else:
        flatten_X = X.reshape((-1, num_features))
        logits = torch.matmul(flatten_X, W) + b
    # Both backends share the same softmax implementation on self.
    return self.softmax(logits)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_loop(self, lr, train_data, W, b): | |
losses = [] | |
accuracies = [] | |
if self.library == "tf": | |
for X, Y in train_data: | |
with tf.GradientTape() as tape: | |
X = X / 255.0 | |
# y_hat has shape (batch_size, num_of_classes) | |
y_hat = self.logistic_regression(X, W, b) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from tensorflow.keras.preprocessing import image_dataset_from_directory | |
from torch.utils.data import DataLoader | |
from torchvision import transforms, datasets | |
from torch.utils.data.sampler import SubsetRandomSampler | |
# This function takes 3 arguments: the directory of your image folder, the fraction of the data
# used for validation (e.g. 0.1 = 10%), and the library to load with: "tf" (TensorFlow) or "pt" (PyTorch).
def load_imagefolder(dir, val_size=0.1, library = "tf"): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == "__main__": | |
logger = logger(__name__) | |
lib = "pt" | |
train_set, val_set = load_imagefolder("../workspace_7/GTSRB/Final_Training/Images/", 0.1, lib) | |
train_class = TrainModel(lib) | |
epochs = 10 | |
lr = 0.1 | |
# number of classes of the dataset | |
num_outputs = 43 | |
NewerOlder