Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Last active May 10, 2022 15:59
Show Gist options
  • Save AlessandroMondin/6e8e1e7fc7d2b3c61a78355086511148 to your computer and use it in GitHub Desktop.
Save AlessandroMondin/6e8e1e7fc7d2b3c61a78355086511148 to your computer and use it in GitHub Desktop.
if __name__ == "__main__":
    # Script entry point: load the image dataset, build a single weight
    # matrix and bias for the selected backend, then run the train/validate
    # loop for a fixed number of epochs.

    # NOTE: this rebinds the name `logger` from the callable used to create
    # the logger to the object it returns.
    logger = logger(__name__)

    # Backend selector handed to both the loader and the trainer. The branch
    # below uses `tf` when the trainer reports "tf" and `torch` otherwise.
    lib = "pt"
    # presumably 0.1 is the validation split fraction; verify against
    # load_imagefolder.
    train_set, val_set = load_imagefolder(
        "../workspace_7/GTSRB/Final_Training/Images/", 0.1, lib
    )
    train_class = TrainModel(lib)

    # Hyperparameters.
    epochs = 10
    lr = 0.1
    # number of classes of the dataset
    num_outputs = 43

    # Take one batch only to size the parameters: num_inputs is the product
    # of every dimension of X after the first.
    # assumes the first dimension of X is the batch axis; TODO confirm
    X, y = next(iter(train_set))

    if train_class.library != "tf":
        # PyTorch parameters: small random normal weights, zero bias, both
        # created with requires_grad=True.
        num_inputs = torch.prod(torch.tensor(X.shape)[1:], 0).item()
        W = torch.normal(
            mean=0,
            std=0.01,
            size=(num_inputs, num_outputs),
            requires_grad=True,
        )
        b = torch.zeros(num_outputs, requires_grad=True)
    else:
        # TensorFlow parameters: same initialisation, wrapped in tf.Variable.
        num_inputs = tf.reduce_prod(X.shape[1:], 0)
        W = tf.Variable(
            tf.random.normal(
                shape=(num_inputs, num_outputs), mean=0, stddev=0.01
            )
        )
        b = tf.Variable(tf.zeros(num_outputs))

    # training
    # NOTE(review): the `{:1f}` spec below means minimum width 1 with the
    # default precision (six decimals). If one decimal place was intended
    # the spec would be `{:.1f}`; confirm before changing.
    for epoch in range(1, epochs + 1):
        logger.info('Epoch {}'.format(epoch))

        loss, acc = train_class.train_loop(lr, train_set, W, b)
        logger.info(
            'Mean training loss: {:1f}, mean training accuracy {:1f}'.format(
                loss, acc
            )
        )

        val_acc = train_class.val_loop(val_set, W, b)
        logger.info('Mean validation accuracy {:1f}'.format(val_acc))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment