Training Loop
import time

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from tqdm import tqdm

# CNN, class_names, and dataloaders are assumed to be defined elsewhere in the project.
model = CNN(num_classes=len(class_names))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.3)

# Early-stopping settings
patience = 10
best_val_loss = float('inf')
patience_counter = 0
num_epochs = 100

# Per-epoch metrics, written to CSV after training
metrics = {
    'epoch': [], 'train_loss': [], 'train_accuracy': [], 'train_precision': [],
    'train_recall': [], 'train_f1_score': [], 'test_loss': [], 'test_accuracy': [],
    'test_precision': [], 'test_recall': [], 'test_f1_score': [], 'auc_roc': [],
    'training_time': []
}
for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    all_train_labels = []
    all_train_predictions = []
    start_time = time.time()
    train_loader = dataloaders['train']
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        all_train_labels.extend(labels.cpu().numpy())
        all_train_predictions.extend(predicted.cpu().numpy())
        # Running averages shown in the progress bar
        avg_loss = running_loss / total
        accuracy = correct / total
        train_precision = precision_score(all_train_labels, all_train_predictions, average='weighted', zero_division=0)
        pbar.set_description(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {train_precision:.4f}')

    # Epoch-level training metrics
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = correct / total
    train_precision = precision_score(all_train_labels, all_train_predictions, average='weighted', zero_division=0)
    train_recall = recall_score(all_train_labels, all_train_predictions, average='weighted', zero_division=0)
    train_f1 = f1_score(all_train_labels, all_train_predictions, average='weighted', zero_division=0)
    # Evaluation phase
    model.eval()
    test_loader = dataloaders['test']
    test_correct = 0
    test_total = 0
    test_loss = 0.0
    all_test_labels = []
    all_test_predictions = []
    all_test_outputs = []
    pbar_test = tqdm(test_loader, desc='Testing')
    with torch.no_grad():
        for inputs, labels in pbar_test:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
            all_test_labels.extend(labels.cpu().numpy())
            all_test_predictions.extend(predicted.cpu().numpy())
            all_test_outputs.extend(torch.softmax(outputs, dim=1).cpu().numpy())
            # Running averages shown in the progress bar
            current_test_loss = test_loss / test_total
            current_test_acc = test_correct / test_total
            current_test_precision = precision_score(all_test_labels, all_test_predictions, average='weighted', zero_division=0)
            pbar_test.set_description(f'Testing - Loss: {current_test_loss:.4f}, Accuracy: {current_test_acc:.4f}, Precision: {current_test_precision:.4f}')

    # Epoch-level test metrics
    test_loss /= len(test_loader.dataset)
    test_accuracy = test_correct / test_total
    test_precision = precision_score(all_test_labels, all_test_predictions, average='weighted', zero_division=0)
    test_recall = recall_score(all_test_labels, all_test_predictions, average='weighted', zero_division=0)
    test_f1 = f1_score(all_test_labels, all_test_predictions, average='weighted', zero_division=0)
    # roc_auc_score needs class probabilities and more than one class present in the labels
    auc_roc = roc_auc_score(all_test_labels, all_test_outputs, multi_class='ovr', average='weighted') if len(set(all_test_labels)) > 1 else 0.0
    training_time = time.time() - start_time

    # Record this epoch's metrics
    metrics['epoch'].append(epoch + 1)
    metrics['train_loss'].append(epoch_loss)
    metrics['train_accuracy'].append(epoch_acc)
    metrics['train_precision'].append(train_precision)
    metrics['train_recall'].append(train_recall)
    metrics['train_f1_score'].append(train_f1)
    metrics['test_loss'].append(test_loss)
    metrics['test_accuracy'].append(test_accuracy)
    metrics['test_precision'].append(test_precision)
    metrics['test_recall'].append(test_recall)
    metrics['test_f1_score'].append(test_f1)
    metrics['auc_roc'].append(auc_roc)
    metrics['training_time'].append(training_time)

    # Reduce the learning rate when the test loss plateaus
    scheduler.step(test_loss)

    # Early stopping on the test loss
    if test_loss < best_val_loss:
        best_val_loss = test_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch + 1}')
            break

v = 23
metrics_df = pd.DataFrame(metrics)
metrics_df.to_csv(f'training-metrics-v{v}.csv', index=False)
print('Metrics saved.')
model_name = f'skin-cancer-recognition-v{v}.pth'
torch.save(model, model_name)
print(f'Model saved as {model_name}')
print('Training complete.')
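
# A minimal sketch (not part of the original gist) of reloading the saved model for inference.
# Assumes the CNN class definition is importable so the pickled model object can be unpickled;
# on PyTorch >= 2.6, torch.load defaults to weights_only=True, so weights_only=False is
# required to load a fully pickled model.
loaded_model = torch.load(model_name, map_location=device, weights_only=False)
loaded_model.eval()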