Last active
October 13, 2021 01:24
-
-
Save fxsjy/5574345 to your computer and use it in GitHub Desktop.
mnist with sklearn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random
import time

import numpy
from numpy import arange
#from classification import *
from sklearn import metrics
try:
    from sklearn.datasets import fetch_mldata
except ImportError:
    # fetch_mldata is no longer shipped by current scikit-learn releases;
    # keep the name defined so the module still imports.
    fetch_mldata = None
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
def run():
    """Train a random forest on MNIST and print evaluation metrics.

    Loads the ``mnist_784`` dataset, trains on the first ``n_train``
    samples, predicts the following ``n_test`` samples, and prints weighted
    precision, recall and F1 plus the classifier's mean accuracy.

    Returns:
        None. All results are written to standard output.
    """
    mnist = fetch_openml('mnist_784', version=1)
    # Depending on the installed scikit-learn version the loader may return
    # NumPy arrays or pandas objects; asarray yields plain arrays either way
    # so the integer-array indexing below works.
    data = numpy.asarray(mnist.data)
    target = numpy.asarray(mnist.target)

    # Truncate the data
    n_train = 60000
    n_test = 10000

    # Seeds Python's stdlib RNG only. The classifier below is not given a
    # random_state, so its results may still vary between runs.
    random.seed(0)

    # Define training and testing sets. The test range starts right after
    # the training range; starting at n_train + 1 would skip one sample and
    # leave only n_test - 1 test samples.
    train_idx = arange(0, n_train)
    test_idx = arange(n_train, n_train + n_test)
    X_train, y_train = data[train_idx], target[train_idx]
    X_test, y_test = data[test_idx], target[test_idx]

    # Apply a learning algorithm
    print("Applying a learning algorithm...")
    clf = RandomForestClassifier(n_estimators=10, n_jobs=2)
    clf.fit(X_train, y_train)

    # Make a prediction
    print("Making predictions...")
    y_pred = clf.predict(X_test)

    # Evaluate the prediction. The labels are multiclass, so an explicit
    # averaging mode is needed; 'weighted' averages the per-class scores
    # weighted by each class's support.
    print("Evaluating results...")
    print("Precision: \t%s"
          % metrics.precision_score(y_test, y_pred, average='weighted'))
    print("Recall: \t%s"
          % metrics.recall_score(y_test, y_pred, average='weighted'))
    print("F1 score: \t%s"
          % metrics.f1_score(y_test, y_pred, average='weighted'))
    print("Mean accuracy: \t%s" % clf.score(X_test, y_test))
if __name__ == "__main__":
    # Wall-clock timing of the complete run() call, in seconds.
    start_time = time.time()
    results = run()
    end_time = time.time()
    # %-formatting keeps this line valid on both Python 2 and Python 3.
    print("Overall running time: %s" % (end_time - start_time))
Had the same issue with fetch_mldata(). After reading this Stack Overflow answer https://stackoverflow.com/a/51301798/433717 I downloaded the dataset from Kaggle.
One problem: if I use the original .idx files, convert them into a matrix, and train on it, it takes a very long time. Is there a better solution?
I can't import the dataset. Whenever I try to do so, this error pops up:
ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets'
I can't import the dataset. Whenever I try to do so, this error pops up:
ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets'
Use fetch_openml in place of fetch_mldata, and use 'mnist_784' in place of 'MNIST original'.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
It is showing a
"RemoteDisconnected: Remote end closed connection without response"
error in Jupyter Notebook. Can you say what is happening?