Last active November 27, 2018 09:44
import os
import glob
import itertools
import numpy as np
import skimage.color
import skimage.transform
import torch
from import Dataset
# Dataset yielding one channel-1 image per item, replicated to 3 channels so
# it can be fed to 3-channel pre-trained models, paired with a boolean label
# derived from the plate column in the file name.
# NOTE(review): indentation was lost in this copy of the file; the class and
# method bodies must be re-indented before this module will import.
class SingleChannelDataset(Dataset):
# data_dir: a directory path (str) or a list of paths, passed to get_file_list.
# transforms: optional callable applied to the 3-channel float tensor.
def __init__(self, data_dir, transforms=None):
self.data_dir = data_dir
self.file_list = get_file_list(data_dir)
self.image_list = self.parse_image_list()
self.transforms = transforms
# Number of channel-1 images found under data_dir.
def __len__(self):
return len(self.image_list)
# Returns (img_rgb, is_bleb): img_rgb is a float tensor made by stacking the
# resized image three times; is_bleb is True when the plate column is > 8.
def __getitem__(self, index):
img_path = self.image_list[index]
# NOTE(review): the right-hand side of this assignment is missing in this
# copy; presumably it loads the image at img_path. Restore before use.
img =
# NOTE(review): crop_center indexes img as (channels, rows, cols), but the
# resize below targets a 2-D (224, 224) shape. Confirm the array shape the
# (missing) loader produces so that both calls agree.
img = crop_center(img, 1700, 1700)
img_resized = skimage.transform.resize(
img, (224, 224), mode="reflect", anti_aliasing=True
# NOTE(review): the closing parenthesis of the resize() call above is
# missing in this copy.
# Duplicate the single channel three times because the pre-trained models
# expect 3 input channels.
img_rgb = torch.tensor(np.stack([img_resized]*3)).float()
if self.transforms is not None:
img_rgb = self.transforms(img_rgb)
column = get_column(img_path)
# Label: True (bleb) for plate columns above 8, False otherwise.
# NOTE(review): this is a Python bool. Confirm the DataLoader collates it to
# an integer (long) tensor, because CrossEntropyLoss expects class indices.
is_bleb = column > 8
# The 3-channel duplication happens above, in the np.stack call.
# It is needed because the pre-trained models assume 3 input channels.
return img_rgb, is_bleb
def parse_image_list(self):
"""Keep only the channel-1 (first phase channel) images from file_list."""
return [i for i in self.file_list if get_channel(i) == 1]
# Dataset yielding one multi-channel sample per item. Consecutive files in
# the sorted file list are grouped in pairs. The boolean label comes from the
# plate column of the first file in each pair.
# NOTE(review): indentation was lost in this copy of the file; the class and
# method bodies must be re-indented before this module will import.
class MultiChannelDataset(Dataset):
# data_dir: a directory path (str) or a list of paths, passed to get_file_list.
# transforms: optional callable applied to the resized float tensor.
def __init__(self, data_dir, transforms=None):
self.data_dir = data_dir
self.file_list = get_file_list(data_dir)
self.image_list = self.parse_image_list()
self.transforms = transforms
# Number of file pairs (samples).
def __len__(self):
return len(self.image_list)
# Returns (img_resized, is_bleb) for the pair of file paths at index.
def __getitem__(self, index):
img_paths = self.image_list[index]
# duplicate second channel to make an RGB image
# NOTE(review): the loading expression on the next line is missing in this
# copy; per the comment above it should build a 3-channel array from
# img_paths. Confirm and restore it.
img =
img_cropped = crop_center(img, 1700, 1700)
# BUG(review): img_cropped is never used. The resize below is applied to the
# uncropped img; pass img_cropped instead if the centre crop is intended.
img_resized = skimage.transform.resize(
img, (3, 224, 224), mode="reflect", anti_aliasing=True
# NOTE(review): the closing parenthesis of the resize() call above is
# missing in this copy.
img_resized = torch.tensor(img_resized).float()
if self.transforms is not None:
img_resized = self.transforms(img_resized)
# Label from the first file of the pair: True (bleb) when column > 8.
column = get_column(img_paths[0])
is_bleb = column > 8
# The resize above targets a (3, 224, 224) output shape, i.e. the 3 input
# channels the pre-trained models assume.
return img_resized, is_bleb
def parse_image_list(self):
# NOTE(review): the next line looks like a docstring whose triple quotes
# were lost in this copy; as bare text it is a syntax error. Re-quote it.
split image list into sublists containing both channels
# Sorting places files with a common name prefix next to each other, and
# chunks(..., 2) then pairs consecutive entries. This presumably assumes
# every site has exactly two files: a missing file shifts every later pair.
sorted_img_list = sorted(self.file_list)
return list(chunks(sorted_img_list, 2))
def chunks(l, n):
    """Lazily yield consecutive slices of ``l`` holding ``n`` items each.

    The final slice is shorter when ``len(l)`` is not a multiple of ``n``.
    """
    yield from (l[start:start + n] for start in range(0, len(l), n))
def get_file_list(data_dir):
    """Return the non-thumbnail ``.tif`` image paths from an ImageXpress experiment.

    Args:
        data_dir: a single directory path (``str``), or a list/tuple of
            directory paths whose image lists are combined.

    Returns:
        A sorted list of paths to ``*.tif`` files. Files whose name contains
        ``"thumb"`` (thumbnails) are excluded. Sorting makes the order, and
        so dataset indices, independent of the filesystem's listing order.

    Raises:
        ValueError: if ``data_dir`` is neither a string nor a list/tuple.
    """
    if isinstance(data_dir, str):
        directories = [data_dir]
    elif isinstance(data_dir, (list, tuple)):
        directories = list(data_dir)
    else:
        raise ValueError(
            "data_dir must be a str or a list of str, "
            f"got {type(data_dir).__name__}"
        )

    image_paths = []
    for directory in directories:
        # Escape the directory so glob metacharacters in its name ("[", "]",
        # "*", "?") are matched literally instead of as a pattern.
        pattern = os.path.join(glob.escape(directory), "*.tif")
        image_paths.extend(glob.glob(pattern))

    # Test only the file name, so a directory whose path happens to contain
    # "thumb" does not cause every image to be dropped.
    return sorted(
        path for path in image_paths
        if "thumb" not in os.path.basename(path)
    )
def get_channel(img_path):
    """Return the channel number encoded in an image file name.

    The channel is the single character at index 1 of the third
    underscore-separated field of the base name (``w1...`` -> 1).
    """
    fields = os.path.basename(img_path).split("_")
    channel_char = fields[2][1]
    return int(channel_char)
def get_column(img_path):
    """Return the plate column number encoded in an image file name.

    The second underscore-separated field of the base name is the well
    (e.g. ``B02``). Everything after its first character is parsed as the
    column (``B02`` -> 2).
    """
    well = os.path.basename(img_path).split("_")[1]
    return int(well[1:])
def crop_center(img, crop_x, crop_y):
    """Return the central ``crop_y`` x ``crop_x`` window of an image array.

    The crop is taken over the last two axes (rows, columns), so both a 2-D
    ``(H, W)`` image and a channels-first ``(C, H, W)`` stack are accepted.
    Any leading axes are kept whole.

    Args:
        img: array whose last two axes are (rows, columns).
        crop_x: width of the window, in columns.
        crop_y: height of the window, in rows.

    Returns:
        A basic slice of ``img`` (a view for NumPy arrays) with shape
        ``(..., min(crop_y, H), min(crop_x, W))``.
    """
    y, x = img.shape[-2:]
    # Clamp at 0. When the crop is larger than the image, a negative start
    # would wrap around to the end of the axis and yield a small off-centre
    # sliver instead of the whole axis.
    start_x = max(0, x // 2 - crop_x // 2)
    start_y = max(0, y // 2 - crop_y // 2)
    return img[..., start_y:start_y + crop_y, start_x:start_x + crop_x]
import argparse
import numpy as np
import torch
import torchvision
from import DataLoader
from data_utils import SingleChannelDataset, MultiChannelDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def softmax(x):
    """Return the softmax of ``x``, normalised over all of its elements.

    The maximum is subtracted before exponentiating so that large inputs do
    not overflow; this does not change the result.
    """
    stabilised = x - np.max(x)
    numerators = np.exp(stabilised)
    return numerators / numerators.sum()
# Script entry point: fine-tune an ImageNet-pretrained ResNet-50, with its
# final layer replaced by a 2-output layer, on MultiChannelDataset, then
# report accuracy on the test set.
# NOTE(review): indentation and several lines/expressions were lost in this
# copy (see the NOTE(review) comments below); it will not run as-is.
if __name__ == "__main__":
# command line arguments
parser = argparse.ArgumentParser()
# NOTE(review): the `parser.add_argument(` opener and the closing `)` around
# the next two lines are missing in this copy.
"-d", "--data_dir", nargs="+", required=True,
help="data directories"
parser.add_argument("--test_data_dir", required=True)
parser.add_argument("--lr", type=float, default=0.005)
parser.add_argument("--log_location", default="history.json")
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--batch_size", type=int, default=16)
# BUG(review): type=bool turns any non-empty string into True, so
# "--freeze False" still freezes the backbone; use action="store_true".
parser.add_argument("--freeze", type=bool, default=False)
args = parser.parse_args()
# load model
model = torchvision.models.resnet50(pretrained=True)
# optionally freeze model weights
# as the last layer is new it won't be frozen and will be updated
# during training
if args.freeze:
for param in model.parameters():
param.requires_grad = False
# Replace the final fully-connected layer with a 2-output layer, one per
# class. It is created after the freeze loop, so its weights stay trainable.
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, 2)
# NOTE(review): the right-hand side is missing in this copy; presumably it
# moves the model to DEVICE, since inputs and class weights are moved there.
model =
# transforms for training data set
# NOTE(review): the transform list and closing `])` are missing in this copy.
transforms = torchvision.transforms.Compose([
# data_set
# NOTE(review): the remaining arguments of this call are missing in this
# copy; presumably transforms=transforms is passed.
data_set = MultiChannelDataset(data_dir=args.data_dir,
test_data_set = MultiChannelDataset(data_dir=args.test_data_dir)
# data_loader
# NOTE(review): the remaining arguments of both DataLoader calls are missing
# in this copy; args.num_workers is otherwise unused.
data_loader = DataLoader(data_set, batch_size=args.batch_size, shuffle=True,
test_data_loader = DataLoader(test_data_set, shuffle=False,
# Only parameters with requires_grad set are optimised; with --freeze this
# is just the new fc layer.
# NOTE(review): the closing arguments of Adam(...) are missing in this copy;
# presumably lr=args.lr.
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
# NOTE(review): scheduler.step(...) is never called in the visible code, so
# the learning rate is never reduced.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
# adjust weights, inverse weighting to class abundance
weights = torch.tensor([0.33, 0.66]).to(DEVICE)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
len_data = len(data_set)
for epoch in range(args.epochs):
print(f"Epoch {epoch+1}/{args.epochs}")
running_loss = 0.0
running_corrects = 0
for batch in tqdm(data_loader):
imgs, labels = batch
# NOTE(review): the DataLoader already yields tensors. torch.tensor() on a
# tensor makes a copy and emits a UserWarning; .to(DEVICE) alone suffices.
imgs = torch.tensor(imgs).to(DEVICE).float()
labels = torch.tensor(labels).to(DEVICE)
outputs = model(imgs)
labels = labels.view(-1)
loss = criterion(outputs, labels)
# BUG(review): no optimizer.zero_grad() / loss.backward() /
# optimizer.step() is visible in this loop, so the weights are never
# updated. Confirm whether these lines were lost.
running_loss += loss.item()
# NOTE(review): the first argument of torch.max is missing in this copy;
# presumably the model outputs (argmax over the class dimension).
_, preds = torch.max(, 1)
running_corrects += torch.sum(preds == labels)
# NOTE(review): running_loss sums per-batch mean losses. Dividing by the
# number of samples therefore understates the epoch loss by roughly
# batch_size; divide by len(data_loader) instead.
epoch_loss = running_loss / len_data
epoch_acc = float(int(running_corrects) / int(len_data))
print(f"epoch acc = {epoch_acc}")
print(f"epoch loss = {epoch_loss}")
print("Testing model...")
# NOTE(review): evaluation below is not wrapped in torch.no_grad(), so
# autograd graphs are built needlessly. If this section sits inside the epoch
# loop, model.train() must be called again before the next epoch.
model = model.eval()
running_corrects = 0.0
len_data = len(test_data_set)
for i, batch in enumerate(test_data_loader):
imgs, labels = batch
imgs = torch.tensor(imgs).to(DEVICE).float()
labels = torch.tensor(labels).to(DEVICE)
outputs = model(imgs)
# Class probabilities for one sample; per the plot title below, index 0 is
# blob and index 1 is bleb.
# NOTE(review): the expression before [0] and the closing `)` are missing
# in this copy.
softmax_output = softmax([0]
labels = labels.view(-1)
_, preds = torch.max(, 1)
running_corrects += torch.sum(preds == labels)
# First image of the batch as a NumPy array, for plotting.
imgs = imgs.cpu().numpy()[0,...]
# move colour channels to the end for matplotlib
imgs = np.transpose(imgs, (1, 2, 0))
# NOTE(review): no plt.imshow / savefig / show is visible, so the image is
# never actually drawn. Confirm whether lines were lost.
plt.title(f"Blob = {softmax_output[0]}\nBleb = {softmax_output[1]}")
test_acc = float(int(running_corrects) / int(len_data))
print(f"Test acc = {test_acc}")
