This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PATH = './alex_net.pth'
# setup: fine-tune an ImageNet-pretrained AlexNet for this dataset's classes.
# NOTE(review): pretrained AlexNet's first conv expects 3-channel input, while the
# ImageFolder transform in this file uses Grayscale(num_output_channels=1) --
# confirm which transform actually feeds this model.
model_ft = models.alexnet(pretrained=True)
# Replace only the last classifier layer so it emits one logit per class.
# The input width is read from the layer being replaced (rather than hard-coding
# 4096) so the new head always matches the backbone.
model_ft.classifier[6] = nn.Linear(model_ft.classifier[6].in_features, len(classes))
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
# All parameters (pretrained backbone and the new head) are optimized.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler.step() calls.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PATH = './densenet169.pth'
# setup: fine-tune an ImageNet-pretrained DenseNet-169 for this dataset's classes.
# NOTE(review): the pretrained model expects 3-channel input, while the ImageFolder
# transform in this file uses Grayscale(num_output_channels=1) -- confirm which
# transform actually feeds this model.
model_ft = models.densenet169(pretrained=True)
# Show the original classifier head before it is replaced.
print(model_ft.classifier)
# Size the new head from the layer it replaces instead of hard-coding 1664.
model_ft.classifier = nn.Linear(model_ft.classifier.in_features, len(classes))
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
# All parameters (pretrained backbone and the new head) are optimized.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler.step() calls.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PATH = './densenet161.pth'
# setup model: fine-tune an ImageNet-pretrained DenseNet-161 for this dataset's classes.
# NOTE(review): the pretrained model expects 3-channel input, while the ImageFolder
# transform in this file uses Grayscale(num_output_channels=1) -- confirm which
# transform actually feeds this model.
model_ft = models.densenet161(pretrained=True)
# Size the new head from the layer it replaces instead of hard-coding 2208.
model_ft.classifier = nn.Linear(model_ft.classifier.in_features, len(classes))
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
# All parameters (pretrained backbone and the new head) are optimized.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler.step() calls.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PATH = './resnet18_net.pth'
# setup model: fine-tune an ImageNet-pretrained ResNet-18 for this dataset's classes.
# NOTE(review): the pretrained model expects 3-channel input, while the ImageFolder
# transform in this file uses Grayscale(num_output_channels=1) -- confirm which
# transform actually feeds this model.
model_ft = models.resnet18(pretrained=True)
# Replace the final fully-connected layer, keeping its input width.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(classes))
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler.step() calls, matching the
# other model setups. train_model takes a scheduler argument; without this line,
# `exp_lr_scheduler` would be undefined here or left over from another setup and
# bound to that setup's optimizer, so stepping it would not affect this model.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter

# Per-class sample counts for the train split, the test split, and the whole
# dataset. int() normalizes every target to a plain Python int, so the counts are
# correct whether `targets` holds ints (as ImageFolder does) or 0-d tensors --
# tensors hash by identity, so Counter would otherwise give each one its own key.
train_classes = [int(dataset.targets[i]) for i in train_data.indices]
print("train:", Counter(train_classes))
test_classes = [int(dataset.targets[i]) for i in test_data.indices]
print("Test:", Counter(test_classes))
print("Total:", dict(Counter(int(t) for t in test_data.dataset.targets)))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Number of training epochs for this run (train_model's own default is 25).
num_epochs = 30
def train_model(model, criterion, optimizer, scheduler, num_epochs=25): | |
since = time.time() | |
best_model_wts = copy.deepcopy(model.state_dict()) | |
best_acc = 0.0 | |
for epoch in range(num_epochs): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Looping through the loader yields one (images, labels) batch per iteration.
# NOTE(review): this walks the entire training set and discards every batch; it
# only demonstrates the iteration pattern and can be removed.
for images, labels in train_loader:
    pass
# Get one batch from a fresh iterator over the loader.
images, labels = next(iter(train_loader))
# Show the 11th sample, or the last one when the batch is smaller, instead of
# raising IndexError on short batches.
indx = min(10, len(images) - 1)
# squeeze() drops the size-1 channel dimension, so any single-channel image size
# works (not only 64x64); cmap='gray' renders it as grayscale.
plt.imshow(images[indx].squeeze(), cmap='gray')
# int() works on CPU and CUDA tensors alike (unlike .numpy()).
plt.title(label_map[int(labels[indx])])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Image root directory; ImageFolder treats each sub-directory as one class.
root = 'data/'
# Per-image preprocessing applied on load.
# NOTE(review): this produces 1-channel tensors, but the pretrained torchvision
# models set up in this file expect 3-channel input -- confirm whether a separate
# transform is used for them.
data_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),  # collapse to one channel
        transforms.Resize((64, 64)),  # fixed 64x64 spatial size
        transforms.ToTensor(),  # PIL image -> float tensor
    ]
)
dataset = ImageFolder(root, transform=data_transform)
# Report the discovered class names and their integer label assignment.
print(dataset.classes)
print(dataset.class_to_idx)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision | |
from torchvision import datasets, models, transforms | |
from torchvision.datasets import ImageFolder | |
from torch.utils.data import DataLoader, random_split | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torch.optim import lr_scheduler |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
# Run YOLOv5 inference on the images under data/test/ with the best.pt weights
# from the RoadTrainModel4 training run, using a 0.25 confidence threshold;
# --name sets the output run name (YOLOv5 presumably writes to
# runs/detect/RoadTestModel -- confirm against the yolov5 checkout).
# NOTE: `!` is an IPython/Jupyter shell escape; this line only runs in a notebook.
! python yolov5/detect.py --source data/test/ --weights yolov5/runs/train/RoadTrainModel4/weights/best.pt --conf 0.25 --name RoadTestModel