Skip to content

Instantly share code, notes, and snippets.

@123epsilon
Last active August 27, 2021 18:55
Show Gist options
  • Save 123epsilon/5ac4de0e43da895470ddca807acb5670 to your computer and use it in GitHub Desktop.
Save 123epsilon/5ac4de0e43da895470ddca807acb5670 to your computer and use it in GitHub Desktop.
class Brain_MRI_Segmentation_Dataset(data.Dataset):
    """Map-style dataset yielding ``(image, mask)`` tensor pairs for segmentation.

    Each element of ``inputs`` must be indexable with the image path at
    position 0 and the mask path at position 1; any further fields (e.g. a
    label) are ignored by this class.

    Args:
        inputs: Sequence of per-sample records such as
            ``(img_path, mask_path, label)``.
        transform: Optional callable applied to the image tensor and, in a
            separate call, to the mask tensor.
    """

    def __init__(self, inputs, transform=None):
        self.inputs = inputs
        self.transform = transform
        self.input_dtype = torch.float32
        self.target_dtype = torch.float32

    def __len__(self):
        return len(self.inputs)

    @staticmethod
    def _imread(path, *flags):
        """Read ``path`` with ``cv2.imread``, raising instead of returning ``None``."""
        image = cv2.imread(path, *flags)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None,
            # which would otherwise fail later with an unrelated numpy error.
            raise FileNotFoundError(f"could not read image: {path!r}")
        return image

    def __getitem__(self, index):
        """Return the ``(x, y)`` segmentation pair for sample ``index``.

        ``x`` is the image as a tensor of shape (C, H, W) with ``input_dtype``.
        Its pixel values are not rescaled, and its channels are in the order
        returned by ``cv2.imread`` (BGR).
        ``y`` is the mask divided by 255, as a tensor of shape (1, H, W) with
        ``target_dtype``.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
            ValueError: if the mask is not a single-channel 2-D image.
        """
        img_path = self.inputs[index][0]
        mask_path = self.inputs[index][1]

        image = self._imread(img_path)
        # HWC -> CHW channel-first layout.
        x = torch.from_numpy(np.transpose(image, (2, 0, 1))).type(self.input_dtype)

        mask = self._imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask.ndim != 2:
            raise ValueError(
                f"expected a single-channel 2-D mask at {mask_path!r}, "
                f"got shape {mask.shape}"
            )
        # Add an explicit channel axis: (H, W) -> (1, H, W). np.resize to a
        # hard-coded (1, 256, 256) would silently tile/truncate other shapes.
        y = torch.from_numpy(mask[np.newaxis, ...] / 255.).type(self.target_dtype)

        if self.transform is not None:
            # NOTE(review): the transform is called independently on x and y,
            # so a randomised transform would draw different parameters for
            # image and mask -- confirm only deterministic transforms are used.
            x = self.transform(x)
            y = self.transform(y)
        return x, y
# Load data: keep only the records whose third field equals 1 (presumably the
# positive-diagnosis label -- confirm against how file_list is built), then
# hold out 30% of them for validation.
positive_diagnoses = [record for record in file_list if record[2] == 1]
mri_dataset = Brain_MRI_Segmentation_Dataset(positive_diagnoses)

validation_size = int(0.3 * len(mri_dataset))
# NOTE: no generator is passed, so the split draws from torch's global RNG and
# differs between runs unless the global seed is fixed elsewhere.
train_set, val_set = data.random_split(
    mri_dataset,
    [len(mri_dataset) - validation_size, validation_size],
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = data.DataLoader(dataset=train_set, batch_size=2, shuffle=True)
val_loader = data.DataLoader(dataset=val_set, batch_size=2, shuffle=False)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment