test_generator = DataGeneratorFolder(root_dir='./data/road_segmentation_ideal/training',
                                     image_folder='input/',
                                     mask_folder='output/',
                                     nb_y_features=1)
train_generator = DataGeneratorFolder(root_dir='./data/road_segmentation_ideal/training',
                                      image_folder='input/',
                                      mask_folder='output/',
                                      batch_size=4,
                                      image_size=512,
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
# reduces learning rate on plateau
lr_reducer = ReduceLROnPlateau(factor=0.1,
                               cooldown=10,
                               patience=10, verbose=1,
                               min_lr=0.1e-5)
# model autosave callbacks
mode_autosave = ModelCheckpoint("./weights/road_crop.efficientnetb0imgsize.h5",
                                monitor='val_iou_score',
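The `lr_reducer` callback above is configured with `factor=0.1`, `patience=10`, `cooldown=10` and `min_lr=0.1e-5` (that is, 1e-6). As a rough illustration of what those settings mean, the following framework-free sketch mimics a reduce-on-plateau rule in plain Python. The class name `PlateauLR`, the `min_delta` threshold, and the assumption that the monitored value is a loss to be minimized are choices made for this example, not taken from the snippet; the Keras callback may differ in details.

```python
class PlateauLR:
    """Simplified, framework-free imitation of a reduce-on-plateau schedule.

    Hypothetical helper for illustration only; it is not the Keras callback.
    """

    def __init__(self, lr, factor=0.1, patience=10, cooldown=10,
                 min_lr=1e-6, min_delta=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.min_delta = min_delta   # assumed improvement threshold
        self.best = float("inf")     # assumes a loss that should decrease
        self.wait = 0                # epochs without improvement
        self.cooldown_left = 0       # epochs remaining in cooldown

    def step(self, val_loss):
        """Update the state with one epoch's validation loss, return the lr."""
        if self.cooldown_left > 0:
            # During cooldown, stagnating epochs are not counted.
            self.cooldown_left -= 1
            self.wait = 0
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        elif self.cooldown_left == 0:
            self.wait += 1
            if self.wait >= self.patience and self.lr > self.min_lr:
                # Shrink the learning rate, but never below the floor.
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.cooldown_left = self.cooldown
                self.wait = 0
        return self.lr


# Usage: a loss that stops improving triggers a reduction after `patience` epochs.
schedule = PlateauLR(lr=1e-3, patience=2, cooldown=1)
learning_rates = [schedule.step(loss) for loss in [1.0, 1.0, 1.0, 1.0]]
```

With the values from the snippet, ten stagnating epochs would be needed before each tenfold reduction, followed by ten epochs in which stagnation is ignored.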
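# Assumed imports: the Unet signature, loss and metric names match the segmentation_models package
from keras.optimizers import Adam
from segmentation_models import Unet
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

model = Unet(backbone_name='efficientnetb0', encoder_weights='imagenet', encoder_freeze=False)
model.compile(optimizer=Adam(), loss=bce_jaccard_loss, metrics=[iou_score])
history = model.fit_generator(train_generator, shuffle=True,
                              epochs=50, workers=4, use_multiprocessing=True,
                              validation_data=test_generator,
                              verbose=1, callbacks=callbacks)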
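The model above is compiled with `iou_score` as its metric. For readers unfamiliar with it, the sketch below computes intersection over union for a single binary mask in plain Python. The function name `binary_iou`, the 0.5 threshold, and the `smooth` term are assumptions for illustration; the library metric operates on batched tensors and may differ in details.

```python
def binary_iou(pred, target, threshold=0.5, smooth=1e-5):
    """Intersection over union between a predicted probability mask and a
    binary ground-truth mask, both given as nested lists of equal shape.

    Hypothetical helper for illustration; not the library's `iou_score`.
    """
    intersection = 0
    union = 0
    for pred_row, target_row in zip(pred, target):
        for p, t in zip(pred_row, target_row):
            p_on = p > threshold   # binarize the prediction
            t_on = t > 0.5         # ground truth is expected to be 0 or 1
            intersection += p_on and t_on
            union += p_on or t_on
    # `smooth` avoids division by zero when both masks are empty.
    return (intersection + smooth) / (union + smooth)


# Usage: one pixel overlaps, two pixels are covered by either mask.
predicted = [[0.9, 0.2],
             [0.8, 0.1]]
ground_truth = [[1, 0],
                [0, 0]]
score = binary_iou(predicted, ground_truth)  # close to 1/2
```

A loss such as `bce_jaccard_loss` typically needs a differentiable variant that works on probabilities rather than thresholded pixels, which this hard-threshold sketch does not provide.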
def aug_with_crop(image_size=256, crop_prob=1):
    return Compose([
        RandomCrop(width=image_size, height=image_size, p=crop_prob),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        RandomRotate90(p=0.5),
        Transpose(p=0.5),
        ShiftScaleRotate(shift_limit=0.01, scale_limit=0.04, rotate_limit=0, p=0.25),
        RandomBrightnessContrast(p=0.5),
        RandomGamma(p=0.25),
augmented = aug_with_crop(image_size=1024)(image=img, mask=mask)
image_aug = augmented['image']
mask_aug = augmented['mask']
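The call above passes `image` and `mask` together so that geometric transforms stay aligned between the two. Below is a minimal sketch of that idea for the cropping step only, using nested lists in place of arrays. `paired_random_crop` and its `rng` parameter are hypothetical names for this example and are not part of the augmentation library.

```python
import random


def paired_random_crop(image, mask, size, rng=random):
    """Cut the same randomly placed size x size window out of image and mask.

    Hypothetical helper for illustration: both inputs are nested lists with
    identical height and width.
    """
    height, width = len(image), len(image[0])
    if size > height or size > width:
        raise ValueError("crop size exceeds the input dimensions")
    # Draw the window position once and reuse it for both inputs.
    top = rng.randint(0, height - size)
    left = rng.randint(0, width - size)

    def crop(rows):
        return [row[left:left + size] for row in rows[top:top + size]]

    return crop(image), crop(mask)


# Usage: the mask here is derived from the image, so alignment can be checked.
image = [[r * 10 + c for c in range(6)] for r in range(6)]
mask = [[value % 2 for value in row] for row in image]
image_crop, mask_crop = paired_random_crop(image, mask, size=4,
                                           rng=random.Random(0))
```

Drawing the offsets once is the important part: sampling them separately for the image and the mask would leave the labels misaligned with the pixels.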
class SimpleGNN(torch.nn.Module):
    """Original from http://pages.di.unipi.it/citraro/files/slides/Landolfi_tutorial.pdf"""
    def __init__(self, dataset, hidden=64, layers=6):
        super(SimpleGNN, self).__init__()
        self.dataset = dataset
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels=dataset.num_node_features,
                                  out_channels=hidden))
        for _ in range(1, layers):
"""
LightGBM vs TabM binary-classification benchmark with optimized training
Requirements: pip install lightgbm tabm torch pandas scikit-learn tqdm
!git clone https://github.com/Diyago/Tabular-data-generation.git
!mv Tabular-data-generation/Research/data/* data/
"""
"""
LightGBM vs TabM vs RealMLP binary-classification benchmark with optimized training
Requirements: pip install lightgbm tabm torch pandas scikit-learn tqdm "pytabkit[models]"