Skip to content

Instantly share code, notes, and snippets.

Created May 1, 2014 15:07
Show Gist options
  • Save ralphbean/4c2d4105ea2c7e407fb5 to your computer and use it in GitHub Desktop.
Save ralphbean/4c2d4105ea2c7e407fb5 to your computer and use it in GitHub Desktop.
Messing around with linear regression over text data
""" Messing around with scikit-learn. """
import sys
import numpy as np
import scipy.sparse
import sklearn.linear_model
import sklearn.datasets
import sklearn.svm
import sklearn.metrics
import sklearn.decomposition
import sklearn.feature_extraction.text
import sklearn.utils.sparsefuncs
# The fetch_20newsgroups dataset loader uses logging, which we need to set up.
import logging
# Load the 20 newsgroups dataset; ``corpus`` is derived from it below.
# NOTE(review): the comment above says logging must be set up, but no
# configuration call (e.g. logging.basicConfig()) is visible here -- confirm
# it was not lost.
dataset = sklearn.datasets.fetch_20newsgroups(
# NOTE(review): the argument list of the call above and the right-hand side of
# ``corpus =`` are missing from this copy; presumably corpus is (a slice of)
# dataset.data -- restore both from the original source before running.
corpus =
n_samples = len(corpus)

# Regression target: row i is a one-hot vector with 1.0 in column i, i.e. each
# document is regressed onto an indicator of its own position in the corpus.
# np.eye builds this directly. Building it from ``[[0] * n] * n`` would alias
# every row to the same list, and ``row[[i]]`` is not valid list indexing.
# NOTE(review): this is an n_samples x n_samples float matrix, so memory grows
# quadratically with corpus size.
target = np.eye(n_samples)
# Report the corpus size, then turn the raw text into a document-term matrix
# of token counts.
print "* shape of the corpus", len(corpus)
print "Convert text data into numerical vectors"
vectorizer = sklearn.feature_extraction.text.CountVectorizer(
ngram_range=(1, 1), # unigrams only; (1, 1) is already the default
# NOTE(review): the remainder of this argument list and the closing ")" appear
# to be missing from this copy -- restore them from the original source before
# running.
data = vectorizer.fit_transform(corpus)
# NOTE(review): CountVectorizer produces raw term counts, not tf-idf weights,
# so the "tfidf" wording in the message below is inaccurate.
print "* shape of the tfidf vectors", data.shape
# Keep a handle on the unreduced term matrix: the explained-variance check
# that follows needs the per-feature variance of these original vectors.
vectors = data

print("Reduce the dimensionality of the data")
# Project onto 50 components with TruncatedSVD, which accepts the sparse
# matrix as-is. It does not center the data first, so despite the ``pca``
# name this is LSA-style SVD rather than PCA proper.
pca = sklearn.decomposition.TruncatedSVD(n_components=50)
data = pca.fit_transform(vectors)
print("* shape of the pca components %s" % (data.shape,))
# Fraction of the original variance kept by the reduced representation:
# variance of each SVD component divided by the total per-feature variance of
# the unreduced matrix, summed over components.
exp = np.var(data, axis=0)
# NOTE(review): mean_variance_axis0 belongs to older scikit-learn releases;
# newer ones provide sklearn.utils.sparsefuncs.mean_variance_axis(X, axis=0)
# instead -- confirm against the installed version. Index [1] selects the
# variances from the returned pair.
full = sklearn.utils.sparsefuncs.mean_variance_axis0(vectors)[1].sum()
explained_variance_ratios = exp / full
confidence = explained_variance_ratios.sum()
if confidence < 0.8:
    print("explained variance ratio %f < 0.8. Bailing." % confidence)
    # The message says we bail, so actually stop rather than go on to train
    # on a representation that keeps too little of the variance.
    sys.exit(1)
# Fit on the first half of the samples, evaluate on the second half.
# Floor division keeps the split index an int under both Python 2 and 3
# (plain "/" would produce a float slice bound on Python 3).
half = n_samples // 2

print("Training a linear regression model on first half")
regression = sklearn.linear_model.LinearRegression()
regression.fit(data[:half], target[:half])

print("Now predict the values on the second half")
expected = target[half:]
predicted = regression.predict(data[half:])

# The "report" is the mean squared error between held-out targets and
# predictions.
print("Regression report for regression %s:\n%s\n"
      % (regression, sklearn.metrics.mean_squared_error(expected, predicted)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment