Skip to content

Instantly share code, notes, and snippets.

@phizaz
Created March 24, 2021 06:28
Show Gist options
  • Save phizaz/72f18cfb917e6864821f52f69a71f91c to your computer and use it in GitHub Desktop.
Save phizaz/72f18cfb917e6864821f52f69a71f91c to your computer and use it in GitHub Desktop.
MIMIC-CXR chest X-ray dataset of paired images and report text
import dataclasses
import os
import random

import cv2
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import (BertTokenizerFast, PreTrainedTokenizerFast,
                          RobertaTokenizerFast)

from .base_cls import *
from .cxr_augment import *
from .mimic_cls_data import IGNORE_FILES, SPLITS, VIEWS
from .nih14_data import cv2_loader
# Limit OpenCV to a single internal thread.
# NOTE(review): presumably this avoids CPU oversubscription when several
# DataLoader workers decode images in parallel; confirm.
cv2.setNumThreads(1)

# Directory of this module; dataset files are resolved relative to it.
here = os.path.dirname(__file__)
DATASET_PATH = f'{here}/datasets/mimic-cxr'

# Report variant name -> CSV with (at least) 'study_id' and 'text' columns.
REPORT_TYPES = {
    report_type: f'{DATASET_PATH}/{csv_name}'
    for report_type, csv_name in (
        ('finding_only', 'reports_findings_only.csv'),
        ('finding_impression', 'reports_findings_or_impressions.csv'),
        ('abnormal', 'reports_abnormal.csv'),
        ('abnormal_nodiff', 'reports_abnormal_nodiff.csv'),
    )
}
@dataclass
class MimicTextDataConfig(DatasetConfig):
    """Configuration for the paired MIMIC-CXR image/report dataset.

    The ``name`` property builds an identifier from the split, the view, the
    report type, the optional max length, the casing and the transform
    config's own ``name``.
    """
    dataset: str = 'mimic-text'
    # Tokenizer applied to one report line per sample. The dataset code only
    # accepts BertTokenizerFast / RobertaTokenizerFast instances.
    tokenizer: PreTrainedTokenizerFast = None
    # Truncate token ids to at most this many; None disables truncation.
    n_max_sentence_length: int = None
    # Lower-case the report text before tokenizing.
    lower_case: bool = True
    # When False, an empty report yields an empty token list instead of the
    # tokenized blank line.
    blank_line_for_empty_report: bool = True
    # Directory holding the PNG images that the record DICOM paths map into.
    img_dir: str = f'{DATASET_PATH}/images512'
    # Key into SPLITS (train/val/test CSV file names).
    split: str = 'v1'
    # Key into VIEWS (the record CSV for the chosen view).
    view: str = 'front'
    # Key into REPORT_TYPES (which report CSV to read).
    report: str = 'finding_impression'
    # default_factory gives each config its own TransformConfig. A single
    # shared instance default would leak mutations between configs. Python
    # 3.11+ also rejects unhashable instance defaults (e.g. a non-frozen
    # dataclass) with ValueError when the class is created.
    trans_conf: TransformConfig = dataclasses.field(
        default_factory=TransformConfig)

    @property
    def name(self):
        """Identifier, e.g. ``'v1-front-finding_impression_<trans_conf.name>'``."""
        name = f'{self.split}-{self.view}-{self.report}'
        if self.n_max_sentence_length is not None:
            name += f'-len{self.n_max_sentence_length}'
        if not self.lower_case:
            name += '-upper'
        name += f'_{self.trans_conf.name}'
        return name
class MimicTextCombinedDataset:
    """Bundle the train / val / test ``MimicTextDataset`` splits of one config.

    ``SPLITS[conf.split]`` supplies the three split CSV names, which are
    resolved under ``DATASET_PATH``. The train split gets the 'train'
    transform; val and test share a single 'eval' transform.
    """

    def __init__(self, conf: MimicTextDataConfig):
        train_csv, val_csv, test_csv = SPLITS[conf.split]
        augment = make_transform('train', conf.trans_conf)
        deterministic = make_transform('eval', conf.trans_conf)

        def build(csv_name, transform):
            # one split's dataset, reading its CSV from DATASET_PATH
            return MimicTextDataset(f'{DATASET_PATH}/{csv_name}', conf,
                                    transform)

        self.train_data = build(train_csv, augment)
        self.val_data = build(val_csv, deterministic)
        self.test_data = build(test_csv, deterministic)
class MimicTextDataset(Dataset):
    """Paired (chest X-ray image, report line) samples from MIMIC-CXR.

    Records come from the ``VIEWS[conf.view]`` CSV. A record is kept only if:

    * its ``study_id`` is listed in ``split_csv``,
    * the study has a row in ``REPORT_TYPES[conf.report]``, and
    * its ``dicom_id`` is not in ``IGNORE_FILES``.

    Each item pairs one image with ONE randomly chosen line of its study's
    report, tokenized with ``conf.tokenizer``.

    Args:
        split_csv: path to a CSV whose ``study_id`` column selects the split.
        conf: dataset configuration.
        transform: optional callable, invoked as
            ``transform(image=img, bboxes=[])``, that returns a dict with an
            ``'image'`` entry.
    """

    def __init__(
        self,
        split_csv,
        conf: MimicTextDataConfig,
        transform=None,
    ):
        self.conf = conf
        split_df = pd.read_csv(split_csv)
        report_df = pd.read_csv(REPORT_TYPES[conf.report])
        df = pd.read_csv(VIEWS[conf.view])
        # keep only studies belonging to this split
        df = df[df['study_id'].isin(split_df['study_id'])]
        # keep only studies for which we have a report
        df = df[df['study_id'].isin(set(report_df['study_id']))]
        # drop images known to be unreadable
        df = df[~df['dicom_id'].isin(IGNORE_FILES)].reset_index(drop=True)
        self.report_df = report_df
        self.record_df = df
        self.transform = transform
        # study_id -> report text, built once so __getitem__ does an O(1)
        # lookup instead of scanning all of report_df for every item.
        # keep='first' matches the previous ``.iloc[0]`` choice when a
        # study_id appears more than once.
        self._text_by_study = (
            report_df.drop_duplicates('study_id', keep='first')
            .set_index('study_id')['text'].to_dict())

    def __len__(self):
        return len(self.record_df)

    def __getitem__(self, idx):
        study_id = self.record_df.loc[idx, 'study_id']
        input_ids = self._tokenize_report(self._text_by_study[study_id])
        img = self._load_image(self.record_df.loc[idx, 'path'])
        return {
            'img': img,
            'input_ids': input_ids,
            'study_id': study_id,
        }

    def _tokenize_report(self, text):
        """Return the token ids of one randomly chosen line of ``text``."""
        if pd.isna(text):
            # a missing report (NaN in the CSV) is treated as empty
            text = ''
        if self.conf.lower_case:
            text = text.lower()
        if text == '' and not self.conf.blank_line_for_empty_report:
            return []
        # An empty report splits into [''], so a blank line gets tokenized.
        line = random.choice(text.split('\n'))
        tokenizer = self.conf.tokenizer
        if not isinstance(tokenizer, (BertTokenizerFast, RobertaTokenizerFast)):
            raise NotImplementedError(
                f'unsupported tokenizer type: {type(tokenizer).__name__}')
        # Special tokens are always added. Per the original note, the BERT
        # tokenizer errors on an empty input ('') otherwise.
        input_ids = tokenizer(line,
                              return_attention_mask=False,
                              add_special_tokens=True)['input_ids']
        if self.conf.n_max_sentence_length is not None:
            input_ids = input_ids[:self.conf.n_max_sentence_length]
        return input_ids

    def _load_image(self, dcm_path):
        """Load (and optionally transform) the PNG converted from ``dcm_path``."""
        img_path = dcm_path
        # the images are stored as PNGs next to where the .dcm name points
        if img_path.endswith('.dcm'):
            img_path = img_path[:-len('.dcm')] + '.png'
        # record paths carry a leading 'files/' that img_dir does not have
        prefix = 'files/'
        if img_path.startswith(prefix):
            img_path = img_path[len(prefix):]
        img = cv2_loader(f'{self.conf.img_dir}/{img_path}')
        if self.transform:
            img = self.transform(image=img, bboxes=[])['image']
        return img
def add_bos_token(sentences, bos_token_id):
    """Return new token-id lists, each prefixed with ``bos_token_id``.

    The input lists are left untouched; every output list is a fresh list.
    """
    return [[bos_token_id] + tokens for tokens in sentences]
def add_eos_token(sentences, eos_token_id):
    """Return token-id lists that each end with ``eos_token_id``.

    An EOS is appended only when a sentence does not already end with it.
    An empty sentence becomes ``[eos_token_id]`` instead of raising
    ``IndexError``; empty lists occur when empty reports produce no tokens.
    Input lists are never mutated. A sentence that already ends with EOS is
    returned as the same object, not a copy.
    """
    return [
        each if each and each[-1] == eos_token_id else each + [eos_token_id]
        for each in sentences
    ]
class MimicTextCollator:
    """Collate dataset samples into a batch with right-padded token ids.

    ``conf.tokenizer.pad_token_id`` is the padding value. Images are stacked
    with ``torch.stack``; the ``'study_id'`` entries are not carried over.
    """

    def __init__(self, conf):
        self.conf = conf

    def __call__(self, data):
        longest = max(len(sample['input_ids']) for sample in data)
        pad_id = self.conf.tokenizer.pad_token_id
        images = [sample['img'] for sample in data]
        padded = [
            sample['input_ids'] + [pad_id] * (longest - len(sample['input_ids']))
            for sample in data
        ]
        return {
            'img': torch.stack(images),
            'input_ids': torch.LongTensor(padded),
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment