Created
August 13, 2023 23:17
-
-
Save vhxs/bb74d2ee8b37b550772b8df29c37476e to your computer and use it in GitHub Desktop.
ChatGPT-suggested fine-tuning example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.models as models | |
import torchvision.transforms as transforms | |
import torchvision.datasets as datasets | |
# Load a ResNet-18 with pre-trained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; confirm against the installed version before switching.
model = models.resnet18(pretrained=True)

# Freeze every pre-trained parameter so that only the new head is trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the last fully connected layer with a fresh head for the new task
# (cats vs. dogs). The parameters of a newly constructed nn.Linear require
# gradients by default, so the head is trainable without further action.
# (Assigning `requires_grad` on the module object itself, rather than on its
# parameters, has no effect on autograd, so no such assignment is made here.)
num_classes = 2  # We have two classes: cats and dogs
model.fc = nn.Linear(model.fc.in_features, num_classes)
# Per-channel mean / std handed to Normalize for both splits.
# NOTE(review): presumably the statistics the pre-trained weights were
# trained with -- confirm against the torchvision model documentation.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random horizontal flip, then tensor
# conversion and normalization.
_train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Validation pipeline: fixed resize + center crop, then the same tensor
# conversion and normalization.
_val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Preprocessing pipelines keyed by split name.
data_transforms = {
    'train': _train_pipeline,
    'val': _val_pipeline,
}
# Root of the dataset; it must contain 'train' and 'val' subdirectories.
data_dir = '/home/vsaraph/Downloads/newdataset'  # Replace with your dataset path

# One ImageFolder dataset per split, each with its own preprocessing pipeline.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('train', 'val')
}

# One DataLoader per split. Only the training split is shuffled: the
# validation pass just counts correct predictions, so sample order is
# irrelevant there and shuffling would only add work.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=64,
        shuffle=(split == 'train'),
        num_workers=4,
    )
    for split in ('train', 'val')
}
# Loss function, and an optimizer that updates only the parameters of the
# final fully connected layer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# Train the model
num_epochs = 5
for epoch in range(num_epochs):
    # Explicitly select training mode so this loop behaves the same when it
    # is re-run after an evaluation pass has switched the model to eval mode.
    model.train()
    for i, (inputs, labels) in enumerate(dataloaders['train']):
        print(f"epoch {epoch}, batch {i}")
        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
# Measure classification accuracy on the validation split.
model.eval()
correct = 0
total = 0
# No gradients are needed while scoring, so disable autograd tracking.
with torch.no_grad():
    for inputs, labels in dataloaders['val']:
        # The predicted class is the index of the largest output along dim 1.
        predicted = model(inputs).argmax(dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())

accuracy = 100 * correct / total
print('Accuracy on validation set: {:.2f}%'.format(accuracy))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import shutil | |
from sklearn.model_selection import train_test_split | |
# Source directories holding the original cat and dog images.
original_cat_dir = '/home/vsaraph/Downloads/PetImages/Cat'  # Replace with the actual path
original_dog_dir = '/home/vsaraph/Downloads/PetImages/Dog'  # Replace with the actual path

# Destination root for the re-organized dataset and its two split folders.
new_dataset_dir = '/home/vsaraph/Downloads/newdataset'  # Replace with the desired path
train_dir = os.path.join(new_dataset_dir, 'train')
val_dir = os.path.join(new_dataset_dir, 'val')

# Create the root, both split folders, and a 'cat' / 'dog' folder inside each
# split. exist_ok=True makes re-running this step harmless.
for directory in (
    new_dataset_dir,
    train_dir,
    val_dir,
    os.path.join(train_dir, 'cat'),
    os.path.join(train_dir, 'dog'),
    os.path.join(val_dir, 'cat'),
    os.path.join(val_dir, 'dog'),
):
    os.makedirs(directory, exist_ok=True)
# Split data into train and val sets (80% / 20% per class).
# os.listdir returns entries in arbitrary order, so sort the names first;
# otherwise the fixed random_state would not yield the same split on every
# machine or run.
cat_images = sorted(os.listdir(original_cat_dir))
dog_images = sorted(os.listdir(original_dog_dir))
cat_train, cat_val = train_test_split(cat_images, test_size=0.2, random_state=42)
dog_train, dog_val = train_test_split(dog_images, test_size=0.2, random_state=42)
# Move every image from its original class directory into the matching
# split/class folder. Each entry is (file names, source dir, destination dir);
# the order is cat-train, cat-val, dog-train, dog-val.
# Note: shutil.move relocates the files, so the original directories are
# emptied by this step.
move_plan = [
    (cat_train, original_cat_dir, os.path.join(train_dir, 'cat')),
    (cat_val, original_cat_dir, os.path.join(val_dir, 'cat')),
    (dog_train, original_dog_dir, os.path.join(train_dir, 'dog')),
    (dog_val, original_dog_dir, os.path.join(val_dir, 'dog')),
]
for images, src_dir, dst_dir in move_plan:
    for image in images:
        shutil.move(os.path.join(src_dir, image), os.path.join(dst_dir, image))

print("Dataset split into train and val directories.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment