Last active
April 3, 2024 13:12
-
-
Save smihael/60b101c0f04ba869da2fb345c6ae3aa3 to your computer and use it in GitHub Desktop.
Combining several imperfect classifiers to reduce overall classification error.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
# VotingClassifier is no longer used below (see the note in the search loop);
# the import is retained in case other notebook cells reference it.
from sklearn.ensemble import VotingClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import ParameterGrid
from sklearn.svm import SVC

# NOTE: `X` (features) and `y` (labels) are not defined in this cell; they are
# expected to exist already and to support integer-array indexing (X[idx]).


def _split_into_thirds(n_samples, seed):
    """Partition ``range(n_samples)`` into three disjoint index collections.

    The first two parts each hold ``n_samples // 3`` indices drawn without
    replacement; the third part receives every remaining index.  The global
    NumPy RNG is re-seeded with ``seed`` so a given seed always produces the
    same partition.
    """
    np.random.seed(seed)
    every_index = set(range(n_samples))
    third = n_samples // 3
    first = np.random.choice(range(n_samples), size=third, replace=False)
    second = np.random.choice(
        list(every_index - set(first)), size=third, replace=False
    )
    remainder = list(every_index - set(first) - set(second))
    return first, second, remainder


def _majority_vote(predictions, classes):
    """Combine label predictions of several classifiers by hard voting.

    ``predictions`` is a sequence with one 1-D label array per classifier
    (all of equal length); ``classes`` is the sorted array of all labels that
    may occur.  Returns one label per sample.  On a tie the label that sorts
    first wins, because ``argmax`` returns the first maximum.
    """
    encoded = np.searchsorted(classes, np.asarray(predictions))
    # counts[k, j] = number of classifiers that voted class k for sample j.
    counts = np.stack(
        [(encoded == k).sum(axis=0) for k in range(len(classes))]
    )
    return classes[counts.argmax(axis=0)]


# Define the search space
subset_seed_range = range(0, 100)  # Range of seeds to try
param_grid = {
    'C': [0.1, 1, 10],
    'gamma': [0.001, 0.01, 0.1, 1],
    'kernel': ['rbf'],
    # NOTE(review): per the scikit-learn docs SVC ignores random_state unless
    # probability=True, so these five values repeat identical fits.  Kept so
    # the reported parameter dict keeps its original shape.
    'random_state': range(0, 5),
}
all_combinations = [
    (seed, params)
    for seed in subset_seed_range
    for params in ParameterGrid(param_grid)
]

# Initialize variables to track the best performance
best_ensemble_accuracy = 0
best_combination = None
n_passed = 0

classes = np.unique(y)

# Perform the search
for seed, params in all_combinations:
    # Create the three disjoint training subsets for this seed
    subsets = _split_into_thirds(len(X), seed)

    # Train one SVC per subset with the given parameters
    classifiers = [SVC(**params).fit(X[idx], y[idx]) for idx in subsets]

    # Predict once per base classifier; reused for the vote and the scores
    base_predictions = [clf.predict(X) for clf in classifiers]

    # Build the ensemble prediction from the classifiers fitted above.
    # VotingClassifier.fit(X, y) is deliberately not used: it fits fresh
    # clones of the estimators on the full data, which would discard the
    # subset-trained models that the ensemble is being compared against.
    ensemble_predictions = _majority_vote(base_predictions, classes)

    # Evaluate the ensemble and base classifiers
    ensemble_accuracy = accuracy_score(y, ensemble_predictions)
    base_accuracies = [accuracy_score(y, pred) for pred in base_predictions]

    # We don't want 'bad' examples: keep only runs where the ensemble is
    # strictly better than every one of its members
    passed = all(acc < ensemble_accuracy for acc in base_accuracies)
    if passed:
        n_passed += 1
        # Check if this combination is the best so far
        if ensemble_accuracy > best_ensemble_accuracy:
            best_ensemble_accuracy = ensemble_accuracy
            best_combination = (seed, params)

# Output the results
print(f"Best ensemble accuracy: {best_ensemble_accuracy:.2f}")
if best_combination is None:
    # Nothing passed the filter; indexing None would raise TypeError.
    print("Best combination: none (no ensemble beat all of its base classifiers)")
else:
    print(f"Best combination: Seed={best_combination[0]}, Parameters={best_combination[1]}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment