This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| LightGBM vs TabM binary-classification benchmark with optimized training | |
| Requirements: pip install lightgbm tabm torch pandas scikit-learn tqdm | |
| !git clone https://github.com/Diyago/Tabular-data-generation.git | |
| !mv Tabular-data-generation/Research/data/* data/ | |
| """ | |
| """ | |
| LightGBM vs TabM vs RealMLP binary-classification benchmark with optimized training | |
| Requirements: pip install lightgbm tabm torch pandas scikit-learn tqdm "pytabkit[models]" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class SimpleGNN(torch.nn.Module): | |
| """Original from http://pages.di.unipi.it/citraro/files/slides/Landolfi_tutorial.pdf""" | |
| def __init__(self, dataset, hidden=64, layers=6): | |
| super(SimpleGNN, self).__init__() | |
| self.dataset = dataset | |
| self.convs = torch.nn.ModuleList() | |
| self.convs.append(GCNConv(in_channels=dataset.num_node_features, | |
| out_channels=hidden)) | |
| for _ in range(1, layers): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| augmented = aug_with_crop(image_size = 1024)(image=img, mask=mask) | |
| image_aug = augmented['image'] | |
| mask_aug = augmented['mask'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def aug_with_crop(image_size = 256, crop_prob = 1): | |
| return Compose([ | |
| RandomCrop(width = image_size, height = image_size, p=crop_prob), | |
| HorizontalFlip(p=0.5), | |
| VerticalFlip(p=0.5), | |
| RandomRotate90(p=0.5), | |
| Transpose(p=0.5), | |
| ShiftScaleRotate(shift_limit=0.01, scale_limit=0.04, rotate_limit=0, p=0.25), | |
| RandomBrightnessContrast(p=0.5), | |
| RandomGamma(p=0.25), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| model = Unet(backbone_name = 'efficientnetb0', encoder_weights='imagenet', encoder_freeze = False) | |
| model.compile(optimizer = Adam(), loss=bce_jaccard_loss, metrics=[iou_score]) | |
| history = model.fit_generator(train_generator, shuffle =True, | |
| epochs=50, workers=4, use_multiprocessing=True, | |
| validation_data = test_generator, | |
| verbose = 1, callbacks=callbacks) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard | |
| # reduces learning rate on plateau | |
| lr_reducer = ReduceLROnPlateau(factor=0.1, | |
| cooldown= 10, | |
| patience=10,verbose =1, | |
| min_lr=0.1e-5) | |
| # model autosave callbacks | |
| mode_autosave = ModelCheckpoint("./weights/road_crop.efficientnetb0imgsize.h5", | |
| monitor='val_iou_score', |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| test_generator = DataGeneratorFolder(root_dir = './data/road_segmentation_ideal/training', | |
| image_folder = 'input/', | |
| mask_folder = 'output/', | |
| nb_y_features = 1) | |
| train_generator = DataGeneratorFolder(root_dir = './data/road_segmentation_ideal/training', | |
| image_folder = 'input/', | |
| mask_folder = 'output/', | |
| batch_size=4, | |
| image_size=512, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def __getitem__(self, index): | |
| data_index_min = int(index*self.batch_size) | |
| data_index_max = int(min((index+1)*self.batch_size, len(self.image_filenames))) | |
| indexes = self.image_filenames[data_index_min:data_index_max] | |
| this_batch_size = len(indexes) # The last batch can be smaller than the others | |
| X = np.empty((this_batch_size, self.image_size, self.image_size, 3), dtype=np.float32) | |
| y = np.empty((this_batch_size, self.image_size, self.image_size, self.nb_y_features), dtype=np.uint8) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def __init__(self, root_dir=r'../data/val_test', image_folder='img/', mask_folder='masks/', | |
| batch_size=1, image_size=768, nb_y_features=1, | |
| augmentation=None, | |
| suffle=True): | |
| self.image_filenames = listdir_fullpath(os.path.join(root_dir, image_folder)) | |
| self.mask_names = listdir_fullpath(os.path.join(root_dir, mask_folder)) | |
| self.batch_size = batch_size | |
| self.augmentation = augmentation | |
| self.image_size = image_size | |
| self.nb_y_features = nb_y_features |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| conda install -c conda-forge keras | |
| pip install git+https://github.com/qubvel/efficientnet | |
| pip install git+https://github.com/qubvel/classification_models.git | |
| pip install git+https://github.com/qubvel/segmentation_models | |
| pip install git+https://github.com/albu/albumentations | |
| pip install tta-wrapper |
NewerOlder