Last active: May 10, 2022 15:59
-
-
Save AlessandroMondin/6e8e1e7fc7d2b3c61a78355086511148 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == "__main__":
    import math

    # NOTE(review): this rebinds the module-level name `logger` from the
    # factory to the object it returns, so `logger(...)` cannot be called
    # again afterwards. Kept as-is because other code in this file may rely
    # on the global `logger` being the instance -- confirm before renaming.
    logger = logger(__name__)

    # `lib` is passed to both the data loader and TrainModel; the parameter
    # initialisation below branches on `train_class.library` ("tf" vs. the
    # PyTorch path otherwise).
    lib = "pt"
    # The 0.1 argument is presumably the validation split fraction -- TODO
    # confirm against load_imagefolder's signature.
    train_set, val_set = load_imagefolder("../workspace_7/GTSRB/Final_Training/Images/", 0.1, lib)
    train_class = TrainModel(lib)

    epochs = 10
    lr = 0.1
    # number of classes of the dataset
    num_outputs = 43

    # Peek at a single batch only to learn the per-sample input shape.
    X, _ = next(iter(train_set))
    # Flattened feature count = product of every non-batch dimension. This is
    # a plain int for both frameworks, avoiding the tensor round-trips.
    num_inputs = math.prod(X.shape[1:])

    # Softmax-regression parameters: weights ~ N(0, 0.01), zero biases.
    if train_class.library == "tf":
        W = tf.Variable(tf.random.normal(shape=(num_inputs, num_outputs), mean=0, stddev=0.01))
        b = tf.Variable(tf.zeros(num_outputs))
    else:
        W = torch.normal(mean=0, std=0.01, size=(num_inputs, num_outputs), requires_grad=True)
        b = torch.zeros(num_outputs, requires_grad=True)

    # training
    for epoch in range(1, epochs + 1):
        logger.info(f"Epoch {epoch}")
        loss, acc = train_class.train_loop(lr, train_set, W, b)
        # `.4f` (precision) replaces the original `1f`, which only set a
        # meaningless field width of 1.
        logger.info(f"Mean training loss: {loss:.4f}, mean training accuracy {acc:.4f}")
        val_acc = train_class.val_loop(val_set, W, b)
        logger.info(f"Mean validation accuracy {val_acc:.4f}")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.