Skip to content

Instantly share code, notes, and snippets.

@Swarchal
Last active November 27, 2018 09:44
Show Gist options
  • Save Swarchal/fd07ac732f6f0e770533d13507f448b8 to your computer and use it in GitHub Desktop.
Save Swarchal/fd07ac732f6f0e770533d13507f448b8 to your computer and use it in GitHub Desktop.
"""
docstring
"""
import os
import glob
import itertools
import numpy as np
import skimage.io
import skimage.color
import skimage.transform
import torch
from torch.utils.data import Dataset
class SingleChannelDataset(Dataset):
    """Dataset of single-channel images from ImageXpress experiment directories.

    Only files whose channel number (as parsed by ``get_channel``) is 1 are
    used. Each image is centre-cropped to 1700x1700, resized to 224x224 and
    duplicated into three identical channels.

    Parameters
    ----------
    data_dir : str or list of str
        Directory (or directories) handed to ``get_file_list``.
    transforms : callable, optional
        Applied to the ``(3, 224, 224)`` float tensor of every sample.
    """

    def __init__(self, data_dir, transforms=None):
        self.data_dir = data_dir
        self.file_list = get_file_list(data_dir)
        self.image_list = self.parse_image_list()
        self.transforms = transforms

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the image at ``index``.

        The label is 1 when ``get_column`` of the image path is greater
        than 8, otherwise 0.
        """
        img_path = self.image_list[index]
        img = skimage.io.imread(img_path)
        # crop_center indexes a (channel, y, x) array, so give the image a
        # leading axis for the crop and drop it again afterwards.
        # NOTE(review): assumes imread returns a 2-D (y, x) array for these
        # single-channel files — confirm for the acquisition format in use.
        img = crop_center(img[np.newaxis, ...], 1700, 1700)[0]
        img_resized = skimage.transform.resize(
            img, (224, 224), mode="reflect", anti_aliasing=True
        )
        # convert the single-channel image to 3 channels by duplicating the
        # channel; this is used as the pre-trained models assume 3 channels
        img_rgb = torch.tensor(np.stack([img_resized] * 3)).float()
        if self.transforms is not None:
            img_rgb = self.transforms(img_rgb)
        column = get_column(img_path)
        # int rather than bool so the DataLoader collates labels into an
        # integer tensor usable as CrossEntropyLoss targets
        is_bleb = int(column > 8)
        return img_rgb, is_bleb

    def parse_image_list(self):
        """Just want the first phase channel images"""
        return [i for i in self.file_list if get_channel(i) == 1]
class MultiChannelDataset(Dataset):
    """Dataset of multi-channel images built from pairs of channel files.

    The sorted file list is split into consecutive pairs (see
    ``parse_image_list``). For each pair the second channel is duplicated to
    give three channels, and the stack is centre-cropped to 1700x1700 and
    resized to ``(3, 224, 224)``.

    Parameters
    ----------
    data_dir : str or list of str
        Directory (or directories) handed to ``get_file_list``.
    transforms : callable, optional
        Applied to the ``(3, 224, 224)`` float tensor of every sample.
    """

    def __init__(self, data_dir, transforms=None):
        self.data_dir = data_dir
        self.file_list = get_file_list(data_dir)
        self.image_list = self.parse_image_list()
        self.transforms = transforms

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the file pair at ``index``.

        The label is 1 when ``get_column`` of the first path is greater
        than 8, otherwise 0.
        """
        channel_paths = self.image_list[index]
        # duplicate the second channel to make a 3-channel image, as the
        # pre-trained models assume 3 channels. Build a new list: appending
        # to the stored sublist would make it grow by one path per access.
        img_paths = list(channel_paths) + [channel_paths[1]]
        img = skimage.io.imread_collection(img_paths).concatenate()
        img_cropped = crop_center(img, 1700, 1700)
        # resize the cropped stack, the same crop-then-resize order that
        # SingleChannelDataset uses
        img_resized = skimage.transform.resize(
            img_cropped, (3, 224, 224), mode="reflect", anti_aliasing=True
        )
        img_resized = torch.tensor(img_resized).float()
        if self.transforms is not None:
            img_resized = self.transforms(img_resized)
        column = get_column(img_paths[0])
        # int rather than bool so the DataLoader collates labels into an
        # integer tensor usable as CrossEntropyLoss targets
        is_bleb = int(column > 8)
        return img_resized, is_bleb

    def parse_image_list(self):
        """
        split image list into sublists containing both channels

        NOTE(review): assumes the channel files of one image sort next to
        each other and that every image has exactly two files — confirm. An
        odd file count leaves a final one-path sublist that ``__getitem__``
        cannot index.
        """
        sorted_img_list = sorted(self.file_list)
        return list(chunks(sorted_img_list, 2))
def chunks(l, n):
    """Lazily yield consecutive slices of ``l`` holding ``n`` items each.

    The final slice holds the remainder when ``len(l)`` is not a multiple
    of ``n``. ``l`` may be any sliceable sequence (list, str, tuple, ...).
    """
    slice_starts = range(0, len(l), n)
    yield from (l[start:start + n] for start in slice_starts)
def get_file_list(data_dir):
    """Get image paths from one or more ImageXpress experiment directories.

    Parameters
    ----------
    data_dir : str or list/tuple of str
        A single directory, or several directories whose ``*.tif`` files
        are combined into one list.

    Returns
    -------
    list of str
        Paths of the ``.tif`` files found, excluding thumbnails (files whose
        name contains ``"thumb"``).

    Raises
    ------
    ValueError
        If ``data_dir`` is neither a string nor a list/tuple.
    """
    # a string is a single data dir; a list/tuple is several to combine
    if isinstance(data_dir, str):
        directories = [data_dir]
    elif isinstance(data_dir, (list, tuple)):
        directories = data_dir
    else:
        raise ValueError(
            "data_dir must be a str or a list of str, "
            f"got {type(data_dir).__name__}"
        )
    image_paths = []
    for directory in directories:
        image_paths.extend(glob.glob(os.path.join(directory, "*.tif")))
    # filter to remove thumbnails; test only the file name, otherwise a
    # "thumb" anywhere in a directory name would discard every image
    return [p for p in image_paths if "thumb" not in os.path.basename(p)]
def get_channel(img_path):
    """Return the channel number encoded in an image file name.

    The base name is split on underscores and the second character of the
    third field is parsed as an int, e.g. ``plate_B02_w1.tif`` -> 1. Only a
    single digit is read.
    """
    name_fields = os.path.basename(img_path).split("_")
    third_field = name_fields[2]
    return int(third_field[1])
def get_column(img_path):
    """Return the column number encoded in an image file name.

    The base name is split on underscores and everything after the first
    character of the second field is parsed as an int, e.g.
    ``plate_B02_w1.tif`` -> 2.
    """
    second_field = os.path.basename(img_path).split("_")[1]
    digits = second_field[1:]
    return int(digits)
def crop_center(img, crop_x, crop_y):
    """Crop the central ``crop_y`` x ``crop_x`` region of an image array.

    The last two axes of ``img`` are treated as (y, x); any leading axes
    (e.g. channels) are kept whole, so both 2-D ``(y, x)`` and 3-D
    ``(channel, y, x)`` arrays are supported. Along an axis shorter than the
    requested crop, the whole axis is returned.

    Parameters
    ----------
    img : numpy.ndarray
        Array with at least two dimensions.
    crop_x, crop_y : int
        Width and height of the crop.

    Returns
    -------
    numpy.ndarray
        ``img`` sliced to the central region (a view for ndarray input).
    """
    y, x = img.shape[-2:]
    # clamp at 0: a negative start index would wrap around and select the
    # end of the axis instead of its centre
    start_x = max(x // 2 - crop_x // 2, 0)
    start_y = max(y // 2 - crop_y // 2, 0)
    return img[..., start_y:start_y + crop_y, start_x:start_x + crop_x]
"""
docstring
"""
import argparse
import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader
from data_utils import SingleChannelDataset, MultiChannelDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
# run on the GPU whenever torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def softmax(x, axis=None):
    """Numerically stable softmax.

    Parameters
    ----------
    x : array_like
        Input scores.
    axis : int or None, optional
        Axis to normalise along. ``None`` (the default) normalises over every
        element of ``x``. For a 2-D ``(batch, classes)`` input, pass
        ``axis=-1`` to get one distribution per row.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``x`` whose entries along ``axis`` sum to 1.
    """
    x = np.asarray(x)
    # subtract the maximum before exponentiating to avoid overflow
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)
def _str2bool(value):
    """Parse a command-line boolean ("true"/"false", "yes"/"no", "1"/"0").

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    "False", as True, so a flag declared that way can never be turned off.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d", "--data_dir", nargs="+", required=True,
        help="data directories"
    )
    parser.add_argument("--test_data_dir", required=True)
    parser.add_argument("--lr", type=float, default=0.005)
    # NOTE(review): --log_location is parsed but nothing below writes to it
    parser.add_argument("--log_location", default="history.json")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=16)
    # `--freeze` alone means True; `--freeze True` / `--freeze False` also work
    parser.add_argument("--freeze", type=_str2bool, nargs="?", const=True,
                        default=False)
    args = parser.parse_args()

    # load a pre-trained ResNet-50
    model = torchvision.models.resnet50(pretrained=True)
    # optionally freeze model weights
    # as the last layer is new it won't be frozen and will be updated
    # during training
    if args.freeze:
        for param in model.parameters():
            param.requires_grad = False
    # replace the final fully-connected layer with two output neurons,
    # one per class
    num_features = model.fc.in_features
    model.fc = torch.nn.Linear(num_features, 2)
    model = model.to(DEVICE)

    # augmentation, applied to the training data set only
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.RandomVerticalFlip(p=0.5),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        torchvision.transforms.ColorJitter(brightness=0.3),
        torchvision.transforms.ToTensor()
    ])

    data_set = MultiChannelDataset(data_dir=args.data_dir,
                                   transforms=transforms)
    test_data_set = MultiChannelDataset(data_dir=args.test_data_dir)

    data_loader = DataLoader(data_set, batch_size=args.batch_size, shuffle=True,
                             pin_memory=torch.cuda.is_available(),
                             num_workers=args.num_workers)
    # default batch size of 1: the evaluation loop plots and titles the first
    # (only) image of each batch
    test_data_loader = DataLoader(test_data_set, shuffle=False,
                                  pin_memory=torch.cuda.is_available(),
                                  num_workers=args.num_workers)

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    # adjust weights, inverse weighting to class abundance
    weights = torch.tensor([0.33, 0.66]).to(DEVICE)
    criterion = torch.nn.CrossEntropyLoss(weight=weights)

    n_train = len(data_set)
    for epoch in range(args.epochs):
        print(f"Epoch {epoch+1}/{args.epochs}")
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for imgs, labels in tqdm(data_loader):
            imgs = imgs.to(DEVICE).float()
            # CrossEntropyLoss needs integer (long) class targets
            labels = labels.to(DEVICE).long().view(-1)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # loss is a per-batch mean; weight it by the batch size so that
            # dividing by the sample count gives a per-sample epoch mean
            running_loss += loss.item() * imgs.size(0)
            preds = outputs.argmax(dim=1)
            running_corrects += int((preds == labels).sum().item())
        epoch_loss = running_loss / n_train
        epoch_acc = running_corrects / n_train
        print(f"epoch acc = {epoch_acc}")
        print(f"epoch loss = {epoch_loss}")
        # ReduceLROnPlateau only acts when stepped with the monitored metric
        scheduler.step(epoch_loss)

    print("Testing model...")
    model.eval()
    running_corrects = 0
    n_test = len(test_data_set)
    # evaluation needs no gradients
    with torch.no_grad():
        for i, (imgs, labels) in enumerate(test_data_loader):
            imgs = imgs.to(DEVICE).float()
            labels = labels.to(DEVICE).long().view(-1)
            outputs = model(imgs)
            softmax_output = softmax(outputs.cpu().numpy())[0]
            preds = outputs.argmax(dim=1)
            running_corrects += int((preds == labels).sum().item())
            img = imgs.cpu().numpy()[0, ...]
            # move colour channels to the end for matplotlib
            img = np.transpose(img, (1, 2, 0))
            plt.figure(i)
            plt.imshow(img, cmap=plt.cm.gray)
            plt.title(f"Blob = {softmax_output[0]}\nBleb = {softmax_output[1]}")
            plt.show()
    test_acc = running_corrects / n_test
    print(f"Test acc = {test_acc}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment