Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Created December 16, 2012 21:52
Show Gist options
  • Save ogrisel/4313514 to your computer and use it in GitHub Desktop.
Save ogrisel/4313514 to your computer and use it in GitHub Desktop.
Gist to load the amazon7 dataset as a scipy.sparse matrix with hashed features
"""Utility script to load the Amazon 7 Sentiment Analysis Dataset
http://www.cs.jhu.edu/~mdredze/datasets/sentiment/
"""
import HTMLParser
import numpy as np
import scipy.sparse as sp
def _go_to(iterator, expected, raise_if=()):
"""Scan bytes lines to a specific expected line"""
collected = []
for item in iterator:
if item == expected:
return collected
elif item in raise_if:
raise ValueError("Unexpected %r while seeking for %r" % (
item, expected))
else:
collected.append(item)
def parse_review_file(file_path, vectorizer=None, batch_size=1000,
max_negative=2, min_positive=4):
"""Extract text document data and label info from the pseudo XML file"""
html_parser = HTMLParser.HTMLParser()
documents = []
class_labels = []
vectorized = []
size_bytes = 0
invalid_before_rating = set([
"</review>\n",
"<review_text>\n",
"<title>\n",
])
invalid_before_title = set([
"</review>\n",
"<review_text>\n",
])
with open(file_path, 'rb') as f:
line_iterator = iter(f)
try:
while True:
# Scan the byte line making strong assumption on the tag
# ordering and indenting to speed up the parsing process
_go_to(line_iterator, "<review>\n")
_go_to(line_iterator, "<rating>\n",
raise_if=invalid_before_rating)
rating = float(line_iterator.next().strip())
_go_to(line_iterator, "<title>\n",
raise_if=invalid_before_title)
title_lines = _go_to(line_iterator, "</title>\n",
raise_if=("</review>\n",))
_go_to(line_iterator, "<review_text>\n",
raise_if=("</review>\n",))
review_text_lines = _go_to(line_iterator, "</review_text>\n",
raise_if=("</review>\n",))
if rating == 3:
# Skip neutral reviews
continue
# Concatenate lines to form a document
lines = title_lines
lines += ["\n"]
lines += review_text_lines
document = "".join(lines)
# Measure the size in bytes
size_bytes += len(document)
# Decode the charset and unescape the XML entities
document = document.decode('latin1')
document = html_parser.unescape(document)
documents.append(document)
if rating > 3:
class_labels.append(1)
else:
class_labels.append(-1)
if (vectorizer is not None
and len(documents)
and len(documents) % batch_size == 0):
# Vectorize the current batch to release the memory
vectorized.append(vectorizer.transform(documents))
documents = []
except StopIteration:
pass
if vectorizer is not None:
vectorized.append(vectorizer.transform(documents))
documents = sp.vstack(vectorized)
return documents, np.array(class_labels), size_bytes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment