This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class SAM(torch.optim.Optimizer): | |
def __init__(self, params, base_optimizer, rho=0.05, **kwargs): | |
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" | |
defaults = dict(rho=rho, **kwargs) | |
super(SAM, self).__init__(params, defaults) | |
self.base_optimizer = base_optimizer(self.param_groups, **kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sam import SAM | |
... | |
model = YourModel() | |
base_optimizer = torch.optim.SGD # define an optimizer for the "sharpness-aware" update | |
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9) | |
... | |
for input, output in data: |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.datasets import fetch_20newsgroups | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.naive_bayes import MultinomialNB | |
from sklearn.pipeline import make_pipeline | |
from sklearn.metrics import classification_report | |
from imblearn.under_sampling import RandomUnderSampler | |
from imblearn.pipeline import make_pipeline as make_pipeline_imb | |
from collections import Counter | |
categories = [ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model = make_pipeline(TfidfVectorizer(), MultinomialNB()) | |
model.fit(X_train, y_train) | |
y_pred = model.predict(X_test) | |
print(classification_report(y_test,y_pred)) | |
#precision recall f1-score support | |
# 0 0.67 0.94 0.79 319 | |
# 1 0.96 0.92 0.94 389 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model = make_pipeline_imb(TfidfVectorizer(), RandomUnderSampler(), MultinomialNB()) | |
model.fit(X_train, y_train) | |
y_pred = model.predict(X_test) | |
print(classification_report(y_test,y_pred)) | |
# precision recall f1-score support | |
# | |
# 0 0.73 0.87 0.79 319 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import smote_variants as sv | |
import imbalanced_databases as imbd | |
# loading the dataset | |
dataset= imbd.load_iris0() | |
features, target= dataset['data'], dataset['target'] |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# printing the number of samples before smote | |
print('majority class: %d' % np.sum(y == 0)) | |
print('minority class: %d' % np.sum(y == 1)) | |
#majority class: 100 | |
#minority class: 50 | |
#The oversampling is carried out by instantiating any oversampler implemented in the package and calling the sample function. | |
oversampler= sv.distance_SMOTE() |