PL segmentation gist
# data link
# https://drive.google.com/file/d/1EwjJx-V-Gq7NZtfiT6LZPLGXD2HN--qT/view?usp=sharing
from glob import glob
import cv2
from PIL import Image
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt

path='D:/image/classification/2D/PSL/data/segmentation/eyth_dataset/'
# In[3]:
# # data issue: the images in two of the mask folders have no file extension
# for i in glob(path+'masks/vid4/**'):
#     os.rename(i, i+'.png')
# for i in glob(path+'masks/vid9/**'):
#     os.rename(i, i+'.png')

# In[4]:
def get_path():
    images=[glob(i+'*.jpg') for i in glob(path+'images/*/')]
    images=sorted([item for sublist in images for item in sublist])
    masks=[glob(i+'*.png') for i in glob(path+'masks/*/')]
    masks=sorted([item for sublist in masks for item in sublist])
    return images,masks
images,masks=get_path()
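
# Sanity check (added, not in the original gist): after sorting, each image
# should pair with the mask that shares its filename stem; a mismatch here
# would silently corrupt training.
for img, msk in zip(images[:3], masks[:3]):
    print(os.path.basename(img), '<->', os.path.basename(msk))
assert len(images) == len(masks), 'every image needs exactly one mask'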
# In[5]:
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(images, masks)
]
print(len(data_dicts))
train_files, val_files=train_test_split(data_dicts,test_size=0.2,random_state=21)
len(train_files),len(val_files)

# In[6]:
image=cv2.imread(train_files[0]['image'])
image=cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
mask=cv2.imread(train_files[0]['label'],0)
print(image.shape)
print(mask.max(),mask.min())
fig,ax=plt.subplots(1,2)
ax[0].imshow(image)
ax[1].imshow(mask,cmap='gray')
# In[7]:
from pytorch_lightning import seed_everything, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor
from torch.utils.data import DataLoader,Dataset
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau,CosineAnnealingWarmRestarts
import torch.nn as nn
import torch
import torchvision
from torch.nn import functional as F

# In[8]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
# with mean=0 and std=1, A.Normalize simply rescales pixels to [0, 1]
train_aug= A.Compose([
    A.Resize(224,224),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0), std=(1)),
    ToTensorV2(p=1.0),
], p=1.0)
val_aug= A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0), std=(1)),
    ToTensorV2(p=1.0),
], p=1.0)
# In[9]:
class DataReader(Dataset):
    def __init__(self,data,transform=None):
        super(DataReader,self).__init__()
        self.data=data
        self.transform=transform
    def __len__(self):
        return len(self.data)
    def __getitem__(self,index):
        image_path=self.data[index]['image']
        mask_path=self.data[index]['label']
        image=cv2.imread(image_path)
        mask=cv2.imread(mask_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        if self.transform:
            transformed =self.transform(image=image,mask=mask)
            image=transformed['image']
            mask=transformed['mask']
        # add a channel dim and map {0, 255} -> {0, 1}
        mask=np.expand_dims(mask,0)/255
        return image,mask
# In[10]:
ds = DataReader(data=data_dicts, transform=train_aug)
loader=DataLoader(ds, batch_size=8, shuffle=True,num_workers=0)
batch= next(iter(loader))
print(batch[0].shape,batch[1].shape)
print(batch[1].max(),batch[1].min())

# In[11]:
plt.figure()
grid_img=torchvision.utils.make_grid(batch[0],4,4)
plt.imshow(grid_img.permute(1, 2, 0))
plt.title('batch of images')
plt.figure()
grid_img=torchvision.utils.make_grid(batch[1],4,4)
plt.imshow(grid_img.permute(1, 2, 0)*255)
plt.title('batch of masks')
# In[12]:
from einops import rearrange
def dice_coef(mask_pred,mask_gt):
    def compute_dice_coefficient(mask_pred,mask_gt, smooth = 0.0001):
        """Compute the Sørensen-Dice coefficient between the ground truth
        mask `mask_gt` and the predicted mask `mask_pred`.
        Args:
            mask_pred: 4-dim tensor. The predicted mask. [B, C, H, W]
            mask_gt: 4-dim tensor. The ground truth mask. [B, 1, H, W]
        Returns:
            the dice coefficient as float. If both masks are empty, the
            smoothing term makes the result 1.
        """
        volume_sum = mask_gt.sum() + mask_pred.sum()
        volume_intersect = (mask_gt * mask_pred).sum()
        return (2*volume_intersect+smooth) / (volume_sum+smooth)
    dice=0
    n_pred_ch = mask_pred.shape[1]
    mask_pred=torch.softmax(mask_pred, 1)
    mask_gt=F.one_hot(mask_gt.long(), num_classes=n_pred_ch) # create one-hot vector
    mask_gt=rearrange(mask_gt, 'd0 d1 d2 d3 d4 -> d0 (d1 d4) d2 d3') # [B,1,H,W,C] -> [B,C,H,W]
    for ind in range(0,n_pred_ch):
        dice += compute_dice_coefficient(mask_gt[:,ind,:,:], mask_pred[:,ind,:,:])
    return dice/n_pred_ch # average over channels
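
# Minimal sanity check (added, not in the original gist): confident, correct
# logits on a toy mask should give a Dice score close to 1.0.
_gt = torch.zeros(1, 1, 4, 4)
_gt[0, 0, :2] = 1                                                     # toy binary mask
_logits = torch.stack([(1 - _gt[:, 0]) * 10, _gt[:, 0] * 10], dim=1)  # (1, 2, 4, 4)
print(dice_coef(_logits, _gt))  # expect a value close to 1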
# In[13]:
from einops import rearrange
def dice_loss(pred,true, softmax=True,sigmoid=False,one_hot=True,background=True,smooth = 0.0001):
    """
    pred: predicted values without any activation applied at the end,
          shape (B,C,H,W), example: (4, 59, 512, 512)
    true: ground truth, shape (B,1,H,W), example: (4, 1, 512, 512)
    softmax: for multi-class
    sigmoid: for binary class
    one_hot: convert true values to one-hot encoding
    background: include the background channel (channel 0) in the loss
    """
    n_pred_ch = pred.shape[1]
    if softmax:
        assert n_pred_ch!=1, 'single channel found'
        pred=torch.softmax(pred, 1)
    if sigmoid:
        pred=torch.sigmoid(pred) # torch.sigmoid takes a single argument
    if one_hot:
        assert n_pred_ch!=1, 'single channel found'
        true=F.one_hot(true.long(), num_classes=n_pred_ch)
        true=rearrange(true, 'd0 d1 d2 d3 d4 -> d0 (d1 d4) d2 d3') # [B,1,H,W,C] -> [B,C,H,W]
    if background is False:
        assert one_hot, 'apply one-hot encoding first' # channel 0 can only be dropped from a one-hot target
        true = true[:, 1:]
        pred = pred[:, 1:]
    reduce_axis=torch.arange(2, len(true.shape)).tolist() # reduce only spatial dimensions (not batch nor channels)
    intersection = torch.sum(true * pred, dim=reduce_axis)
    denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)
    dice= (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - torch.mean(dice) # the batch and channel average
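
# Quick check (added, not in the original gist): reusing the toy tensors from
# above, near-perfect logits should drive the loss toward 0 while random
# logits should sit well above it.
print(dice_loss(_logits, _gt).item())                 # close to 0
print(dice_loss(torch.randn(1, 2, 4, 4), _gt).item()) # noticeably larger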
# In[14]:
from einops import rearrange
def focal_dice_loss(pred,true,softmax=True,alpha=0.5,gamma=2):
    """
    pred: predicted values without any activation applied at the end,
          shape (B,C,H,W), example: (4, 59, 512, 512)
    true: ground truth, shape (B,1,H,W), example: (4, 1, 512, 512)
    """
    n_pred_ch = pred.shape[1]
    # F.cross_entropy expects raw logits (it applies log_softmax itself),
    # so compute it before softmax is applied to pred
    celoss= F.cross_entropy(pred, torch.squeeze(true, dim=1).long(),reduction='none')
    pt=torch.exp(-celoss) # probability of the true class
    focal_loss = alpha * (1-pt)**gamma * celoss # standard focal loss: FL = alpha*(1-pt)^gamma * CE
    focal_loss=torch.mean(focal_loss)
    if softmax:
        assert n_pred_ch!=1, 'single channel found'
        pred=torch.softmax(pred, 1)
    diceloss=dice_loss(pred,true,softmax=False) # softmax False, because already applied
    return 0.5*focal_loss+0.5*diceloss
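
# Toy usage (added, not in the original gist) on the same toy tensors: the
# combined loss should be small for confident correct logits.
print(focal_dice_loss(_logits, _gt).item())
print(focal_dice_loss(torch.randn(1, 2, 4, 4), _gt).item())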
# In[15]:
import segmentation_models_pytorch as smp
import torchmetrics
class OurModel(LightningModule):
    def __init__(self):
        super(OurModel,self).__init__()
        # architecture
        self.layer = smp.Unet(
            encoder_name='resnet18',
            encoder_weights='imagenet',
            in_channels=3,
            classes=2,
        )
        # parameters
        self.lr=1e-3
        self.batch_size=32
        self.numworker=0
        self.iou=torchmetrics.IoU(2) # renamed JaccardIndex in newer torchmetrics releases
    def forward(self,x):
        return self.layer(x)
    def configure_optimizers(self):
        opt=torch.optim.AdamW(self.parameters(), lr=self.lr,weight_decay=1e-5)
        scheduler = CosineAnnealingWarmRestarts(opt,T_0=10, T_mult=1, eta_min=1e-5, last_epoch=-1)
        return {'optimizer': opt,'lr_scheduler':scheduler}
    def train_dataloader(self):
        ds = DataReader(data=train_files, transform=train_aug)
        loader=DataLoader(ds, batch_size=self.batch_size, shuffle=True,num_workers=self.numworker)
        return loader
    def training_step(self,batch,batch_idx):
        image,segment=batch[0], batch[1]
        out=self(image)
        loss=focal_dice_loss(out,segment)
        dice=dice_coef(out,segment)
        iouscore=self.iou(out,segment.type(torch.int8))
        return {'loss':loss,'iou':iouscore,'dice':dice}
    def training_epoch_end(self, outputs):
        loss=torch.stack([x["loss"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        iou=torch.stack([x["iou"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        dice=torch.stack([x["dice"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        print('training loss, iou, dice ',loss, iou, dice)
    def val_dataloader(self):
        ds = DataReader(data=val_files, transform=val_aug)
        loader=DataLoader(ds, batch_size=self.batch_size, shuffle=False,num_workers=self.numworker)
        return loader
    def validation_step(self,batch,batch_idx):
        image,segment=batch[0], batch[1]
        out=self(image)
        loss=focal_dice_loss(out,segment)
        dice=dice_coef(out,segment)
        iouscore=self.iou(out,segment.type(torch.int8))
        self.log('val_loss', loss) # logged so the ModelCheckpoint below can monitor it
        return {'loss':loss,'iou':iouscore,'dice':dice}
    def validation_epoch_end(self, outputs):
        loss=torch.stack([x["loss"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        iou=torch.stack([x["iou"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        dice=torch.stack([x["dice"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        print('validation loss, iou, dice ',loss, iou, dice)
# In[16]:
model = OurModel()

# In[17]:
lr_monitor = LearningRateMonitor(logging_interval='epoch')
checkpoint_callback = ModelCheckpoint(monitor='val_loss',dirpath='unet',
                                      filename='checkpoint')
trainer = Trainer(max_epochs=5,
                  gpus=-1,precision=16,
                  stochastic_weight_avg=True,
                  callbacks=[checkpoint_callback,lr_monitor], # register the callbacks defined above
                  )

# In[18]:
trainer.fit(model)

# In[ ]:
trainer.validate(model)
# In[ ]:
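# Inference sketch (added, not in the original gist): predict a mask for one
# validation image and show it next to the ground truth, mirroring the
# visualization in cell In[6]. model.cpu() guards against the model still
# sitting on the GPU after training.
model=model.cpu().eval()
with torch.no_grad():
    img, gt = DataReader(val_files, val_aug)[0]
    logits = model(img.unsqueeze(0))        # (1, 2, H, W)
    pred = logits.argmax(dim=1)[0].numpy()  # channel argmax -> (H, W) in {0, 1}
fig, ax = plt.subplots(1, 3)
ax[0].imshow(img.permute(1, 2, 0)); ax[0].set_title('image')
ax[1].imshow(gt[0], cmap='gray'); ax[1].set_title('ground truth')
ax[2].imshow(pred, cmap='gray'); ax[2].set_title('prediction')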