Last active
November 27, 2018 09:44
-
-
Save Swarchal/fd07ac732f6f0e770533d13507f448b8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Dataset classes and file-name helpers for loading ImageXpress .tif images.
""" | |
import os | |
import glob | |
import itertools | |
import numpy as np | |
import skimage.io | |
import skimage.color | |
import skimage.transform | |
import torch | |
from torch.utils.data import Dataset | |
class SingleChannelDataset(Dataset):
    """Dataset that serves one image file per item as a 3-channel tensor.

    Only files for which ``get_channel`` returns 1 are used.  Each item is
    an ``(image, is_bleb)`` pair: ``image`` is a float tensor built by
    stacking the resized 224x224 image three times, and ``is_bleb`` is
    ``True`` when the column parsed from the file name is greater than 8.
    """

    def __init__(self, data_dir, transforms=None):
        """Collect the image paths found under ``data_dir``.

        Parameters
        ----------
        data_dir : str or list of str
            Directory (or directories) handed to ``get_file_list``.
        transforms : callable, optional
            Applied to the image tensor before it is returned.
        """
        self.data_dir = data_dir
        self.file_list = get_file_list(data_dir)
        self.image_list = self.parse_image_list()
        self.transforms = transforms

    def __len__(self):
        """Return the number of usable images."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Load, crop, resize and label the image at ``index``."""
        img_path = self.image_list[index]
        img = skimage.io.imread(img_path)
        if img.ndim == 2:
            # crop_center slices a leading axis before the two spatial
            # axes, so lend a plain 2-D image a temporary one.
            img = crop_center(img[np.newaxis, ...], 1700, 1700)[0]
        else:
            img = crop_center(img, 1700, 1700)
        img_resized = skimage.transform.resize(
            img, (224, 224), mode="reflect", anti_aliasing=True
        )
        # convert single channel image to 3 channels by duplicating the
        # channel; this is used as the pre-trained models assume 3 channels
        img_rgb = torch.tensor(np.stack([img_resized] * 3)).float()
        if self.transforms is not None:
            img_rgb = self.transforms(img_rgb)
        column = get_column(img_path)
        is_bleb = column > 8
        return img_rgb, is_bleb

    def parse_image_list(self):
        """Just want the first phase channel images"""
        return [i for i in self.file_list if get_channel(i) == 1]
class MultiChannelDataset(Dataset):
    """Dataset that combines two image files per item into a 3-channel tensor.

    The sorted file list is split into consecutive pairs; the second file of
    each pair is used twice so that the stacked image has three channels.
    Each item is an ``(image, is_bleb)`` pair where ``is_bleb`` is ``True``
    when the column parsed from the first file name is greater than 8.

    NOTE(review): pairing relies on sorting placing both channel files of
    one acquisition next to each other -- confirm against the file naming.
    """

    def __init__(self, data_dir, transforms=None):
        """Collect and pair the image paths found under ``data_dir``.

        Parameters
        ----------
        data_dir : str or list of str
            Directory (or directories) handed to ``get_file_list``.
        transforms : callable, optional
            Applied to the image tensor before it is returned.
        """
        self.data_dir = data_dir
        self.file_list = get_file_list(data_dir)
        self.image_list = self.parse_image_list()
        self.transforms = transforms

    def __len__(self):
        """Return the number of image groups."""
        return len(self.image_list)

    def _channel_paths(self, index):
        """Return the paths for item ``index`` with the second one repeated.

        A new list is built on every call so the entry stored in
        ``self.image_list`` is never modified; appending in place would add
        one more path each time the same item is requested.
        """
        paths = self.image_list[index]
        return [*paths, paths[1]]

    def __getitem__(self, index):
        """Load, crop, resize and label the image group at ``index``."""
        # duplicate second channel to make an RGB image
        img_paths = self._channel_paths(index)
        img = skimage.io.imread_collection(img_paths).concatenate()
        img_cropped = crop_center(img, 1700, 1700)
        # resize the cropped stack (not the full frame) so the crop
        # actually takes effect
        img_resized = skimage.transform.resize(
            img_cropped, (3, 224, 224), mode="reflect", anti_aliasing=True
        )
        img_resized = torch.tensor(img_resized).float()
        if self.transforms is not None:
            img_resized = self.transforms(img_resized)
        column = get_column(img_paths[0])
        is_bleb = column > 8
        return img_resized, is_bleb

    def parse_image_list(self):
        """
        split image list into sublists containing both channels
        """
        sorted_img_list = sorted(self.file_list)
        return list(chunks(sorted_img_list, 2))
def chunks(l, n):
    """Yield consecutive slices of ``l`` that hold at most ``n`` items each.

    The final slice is shorter when ``len(l)`` is not a multiple of ``n``.
    """
    offsets = range(0, len(l), n)
    yield from (l[start:start + n] for start in offsets)
def get_file_list(data_dir):
    """Get image paths from one or more ImageXpress experiment directories.

    Parameters
    ----------
    data_dir : str, os.PathLike, list or tuple
        A single directory, or a list/tuple of directories whose file
        lists are concatenated in the order given.

    Returns
    -------
    list of str
        Paths of the ``*.tif`` files directly inside each directory, sorted
        within each directory, excluding files whose *name* contains
        ``"thumb"``.

    Raises
    ------
    ValueError
        If ``data_dir`` is neither a path nor a list/tuple of paths.
    """
    if isinstance(data_dir, (str, os.PathLike)):
        directories = [data_dir]
    elif isinstance(data_dir, (list, tuple)):
        directories = list(data_dir)
    else:
        raise ValueError(
            "data_dir must be a path or a list of paths, "
            f"got {type(data_dir).__name__}"
        )
    image_paths = []
    for directory in directories:
        # glob returns files in arbitrary order; sort for reproducibility
        found = glob.glob(f"{os.fspath(directory)}/*.tif")
        image_paths.extend(sorted(found))
    # filter image paths to remove thumbnails; only the file name is
    # inspected so a directory called e.g. "thumbs" does not discard
    # every image inside it
    return [p for p in image_paths if "thumb" not in os.path.basename(p)]
def get_channel(img_path):
    """Return the channel number parsed from an image file name.

    The file name (directory part removed) is split on underscores and the
    second character of the third field is converted to ``int``, so only a
    single digit is read.
    """
    filename = os.path.split(img_path)[1]
    fields = filename.split("_")
    channel_field = fields[2]
    return int(channel_field[1])
def get_column(img_path):
    """Return the column number parsed from an image file name.

    The file name (directory part removed) is split on underscores; the
    second field, minus its first character, is converted to ``int``.
    Presumably that field is a well label such as ``B02`` -- confirm
    against the naming scheme of the data.
    """
    _, filename = os.path.split(img_path)
    well_field = filename.split("_")[1]
    return int(well_field[1:])
def crop_center(img, crop_x, crop_y):
    """Crop a centred window out of the last two axes of ``img``.

    Parameters
    ----------
    img : array
        Array whose last two axes are treated as (rows, columns); any
        leading axes are kept whole, so both 2-D images and stacks work.
    crop_x : int
        Width of the window (size along the last axis).
    crop_y : int
        Height of the window (size along the second-to-last axis).

    Returns
    -------
    array
        The cropped view.  Along an axis where the requested size exceeds
        the image size, the full extent of that axis is returned.
    """
    height, width = img.shape[-2:]
    # clamp at 0: a negative start would be read as "count from the end"
    # and silently return a sliver from the far edge of the image
    start_x = max(0, width // 2 - crop_x // 2)
    start_y = max(0, height // 2 - crop_y // 2)
    return img[..., start_y:start_y + crop_y, start_x:start_x + crop_x]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Fine-tune a pre-trained ResNet-50 as a two-class image classifier and evaluate it on a test set.
""" | |
import argparse | |
import numpy as np | |
import torch | |
import torchvision | |
from torch.utils.data import DataLoader | |
from data_utils import SingleChannelDataset, MultiChannelDataset | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
# Device used for the model and every batch: the GPU when CUDA is
# available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def softmax(x, axis=None):
    """Numerically stable softmax.

    Parameters
    ----------
    x : array-like
        Input scores.
    axis : int, optional
        Axis along which to normalise.  With the default ``None`` the
        whole array is normalised at once, so all entries sum to 1; pass
        ``axis=1`` to normalise each row of a batch separately.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``.
    """
    x = np.asarray(x)
    # subtracting the maximum leaves the result unchanged but keeps
    # np.exp from overflowing on large scores
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)
def parse_bool_flag(value):
    """Convert a command-line string to a bool for argparse.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not a recognised spelling of true/false.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


if __name__ == "__main__":
    # command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d", "--data_dir", nargs="+", required=True,
        help="data directories"
    )
    parser.add_argument("--test_data_dir", required=True)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--log_location", default="history.json")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=16)
    # accepts "--freeze", "--freeze True" and "--freeze False"
    parser.add_argument(
        "--freeze", type=parse_bool_flag, nargs="?", const=True,
        default=False
    )
    args = parser.parse_args()

    # load model
    model = torchvision.models.resnet50(pretrained=True)
    # optionally freeze model weights
    # as the last layer is new it won't be frozen and will be updated
    # during training
    if args.freeze:
        for param in model.parameters():
            param.requires_grad = False
    # need to strip final layers and replace with two neurons
    # to account for the two classes
    num_features = model.fc.in_features
    model.fc = torch.nn.Linear(num_features, 2)
    model = model.to(DEVICE)

    # transforms for training data set
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.RandomVerticalFlip(p=0.5),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        torchvision.transforms.ColorJitter(brightness=0.3),
        torchvision.transforms.ToTensor()
    ])

    # data_set
    data_set = MultiChannelDataset(data_dir=args.data_dir,
                                   transforms=transforms)
    test_data_set = MultiChannelDataset(data_dir=args.test_data_dir)

    # data_loader
    data_loader = DataLoader(data_set, batch_size=args.batch_size,
                             shuffle=True,
                             pin_memory=torch.cuda.is_available(),
                             num_workers=args.num_workers)
    # no batch_size given: the test loader yields one image per batch
    test_data_loader = DataLoader(test_data_set, shuffle=False,
                                  pin_memory=torch.cuda.is_available(),
                                  num_workers=args.num_workers)

    # only hand the optimizer the parameters that are still trainable
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    # adjust weights, inverse weighting to class abundance
    weights = torch.tensor([0.33, 0.66]).to(DEVICE)
    criterion = torch.nn.CrossEntropyLoss(weight=weights)

    len_data = len(data_set)
    for epoch in range(args.epochs):
        print(f"Epoch {epoch+1}/{args.epochs}")
        running_loss = 0.0
        running_corrects = 0
        for imgs, labels in tqdm(data_loader):
            # batches are already tensors; move them instead of copying
            imgs = imgs.to(DEVICE).float()
            # CrossEntropyLoss needs integer (long) class indices
            labels = labels.to(DEVICE).view(-1).long()
            #
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            #
            # the criterion returns a batch mean; scale it by the batch
            # size so dividing by the dataset size gives a per-sample value
            running_loss += loss.item() * imgs.size(0)
            preds = outputs.detach().argmax(dim=1)
            running_corrects += (preds == labels).sum().item()
        epoch_loss = running_loss / len_data
        epoch_acc = running_corrects / len_data
        # the plateau scheduler only acts when it is fed the tracked metric
        scheduler.step(epoch_loss)
        print(f"epoch acc = {epoch_acc}")
        print(f"epoch loss = {epoch_loss}")

    print("Testing model...")
    model = model.eval()
    running_corrects = 0
    len_data = len(test_data_set)
    # no gradients are needed for evaluation
    with torch.no_grad():
        for i, (imgs, labels) in enumerate(test_data_loader):
            imgs = imgs.to(DEVICE).float()
            labels = labels.to(DEVICE).view(-1).long()
            outputs = model(imgs)
            # batches hold a single image, so normalising the whole
            # output is the same as normalising that image's scores
            softmax_output = softmax(outputs.cpu().numpy())[0]
            preds = outputs.argmax(dim=1)
            running_corrects += (preds == labels).sum().item()
            img = imgs.cpu().numpy()[0, ...]
            # move colour channels to the end for matplotlib
            img = np.transpose(img, (1, 2, 0))
            plt.figure(i)
            plt.imshow(img, cmap=plt.cm.gray)
            plt.title(
                f"Blob = {softmax_output[0]}\nBleb = {softmax_output[1]}"
            )
    # display every figure once evaluation has finished
    plt.show()
    test_acc = running_corrects / len_data
    print(f"Test acc = {test_acc}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment