Last active: October 16, 2019 11:50
Example Naive Bayes Classifier with Apache Spark Pipeline
+--------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
|category|                text|label|               words|            features|       rawPrediction|         probability|prediction|
+--------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
|    3001|Плойки и наборы V...| 24.0|[плойки, и, набор...|(10000,[326,796,1...|[-174.67716870697...|[6.63481663197049...|      24.0|
|     833|"Чехол-обложка дл...|  1.0|["чехол-обложка, ...|(10000,[514,986,1...|[-379.37151502387...|[5.32678001676623...|       1.0|
|     833|"Чехол-обложка дл...|  1.0|["чехол-обложка, ...|(10000,[514,986,1...|[-379.84825219376...|[2.15785456821554...|       1.0|
|     833|"Чехол-обложка дл...|  1.0|["чехол-обложка, ...|(10000,[290,514,9...|[-395.42735009477...|[6.44323423370500...|       1.0|
|     833|"Чехол-обложка дл...|  1.0|["чехол-обложка, ...|(10000,[290,514,9...|[-396.10251348944...|[6.31147674177529...|       1.0|
|     343|"HP SD SDHC 32GB ...|  4.0|["hp, sd, sdhc, 3...|(10000,[682,728,9...|[-257.91503332110...|[1.39061762573886...|       4.0|
|     833|Накладка на задню...|  1.0|[накладка, на, за...|(10000,[52,262,43...|[-312.52529589453...|[1.36384892167512...|       1.0|
|       9|Шина Nordman RS 1...|  8.0|[шина, nordman, r...|(10000,[1124,1223...|[-70.874699878463...|[8.66152284697728...|       8.0|
|       9|Шина Cordiant Pol...|  8.0|[шина, cordiant, ...|(10000,[50,1125,1...|[-88.024541535325...|[4.88331956423802...|       8.0|
|       9|Шина Amtel NordMa...|  8.0|[шина, amtel, nor...|(10000,[1125,1392...|[-81.198262280603...|[3.17695885661577...|       8.0|
|       9|покрышка Пирелли ...|  8.0|[покрышка, пирелл...|(10000,[104,1786,...|[-71.071831456444...|[4.16786508937252...|       8.0|
|       9|Шина Tigar SIGURA...|  8.0|[шина, tigar, sig...|(10000,[1125,1392...|[-68.318663189674...|[4.99027130148263...|       8.0|
|       9|Шина Kumho 7400 1...|  8.0|[шина, kumho, 740...|(10000,[13,378,11...|[-71.985129861611...|[1.72668174976181...|       8.0|
|       9|Шина Tigar SIGURA...|  8.0|[шина, tigar, sig...|(10000,[1125,3523...|[-69.462434759853...|[2.12061545878723...|       8.0|
|       0|"Стилус для HTC T...|  0.0|["стилус, для, ht...|(10000,[45,575,73...|[-352.85875748691...|[0.99999997337963...|       0.0|
|     833|Корпус для Nokia ...|  1.0|[корпус, для, nok...|(10000,[68,290,57...|[-127.77659993271...|[2.79347683499863...|       1.0|
|       0|Клавиатура Oklick...|  0.0|[клавиатура, okli...|(10000,[770,837,1...|[-235.09281306068...|[1.82375120886902...|       3.0|
|     343|"Флешка SanDisk C...|  4.0|["флешка, sandisk...|(10000,[56,210,30...|[-446.81306611481...|[3.06096855439347...|       4.0|
|     833|"Накладка на задн...|  1.0|["накладка, на, з...|(10000,[53,262,43...|[-271.90628440262...|[1.88877608249557...|       1.0|
|     833|Чехол раскладной ...|  1.0|[чехол, раскладно...|(10000,[171,182,1...|[-214.65871702691...|[4.31375890687199...|       1.0|
+--------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows
>>>
>>>
>>> print "F1 metric = %g" % metric
F1 metric = 0.842165
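The 0.842165 above comes from Spark's "f1" metric, which for MulticlassClassificationEvaluator is a support-weighted average of the per-class F1 scores. A minimal plain-Python sketch of that computation (the toy labels and predictions below are invented for illustration, not taken from the run above):

```python
from collections import Counter

def weighted_f1(labels, preds):
    """Support-weighted F1 over all classes present in the true labels."""
    support = Counter(labels)
    total = len(labels)
    score = 0.0
    for c in support:
        # Per-class confusion counts.
        tp = sum(1 for l, p in zip(labels, preds) if l == c and p == c)
        fp = sum(1 for l, p in zip(labels, preds) if l != c and p == c)
        fn = sum(1 for l, p in zip(labels, preds) if l == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        # Weight each class's F1 by its share of the true labels.
        score += f1 * support[c] / total
    return score

print("weighted F1 = %g" % weighted_f1([0, 0, 1, 1, 2], [0, 1, 1, 1, 2]))
```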
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import HashingTF, Tokenizer, StringIndexer, IndexToString
from pyspark.ml import Pipeline
from pyspark.sql import Row

# Load a two-column CSV (category,text); maxsplit=1 keeps commas inside the text intact.
textFile = sc.textFile("/Users/mitallast/Sites/spark/sell.csv")
data = textFile.map(lambda line: line.split(',', 1)).map(lambda p: Row(category=p[0], text=p[1]))
schemaSell = sqlContext.createDataFrame(data)

# Persist the parsed data as Parquet and reload it.
schemaSell.write.save("/Users/mitallast/Sites/spark/sell.parquet", format="parquet")
schemaSell = sqlContext.read.load("/Users/mitallast/Sites/spark/sell.parquet")

train_data, test_data = schemaSell.randomSplit([0.8, 0.2])

# Index string categories into numeric labels. Fit the indexer up front so its
# labels are available to IndexToString (the unfitted estimator has no .labels).
categoryIndexer = StringIndexer(inputCol="category", outputCol="label").fit(schemaSell)
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=10000)
nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
# Map numeric predictions back to the original category strings.
categoryConverter = IndexToString(inputCol="prediction", outputCol="predCategory", labels=categoryIndexer.labels)

pipeline = Pipeline(stages=[categoryIndexer, tokenizer, hashingTF, nb, categoryConverter])
model = pipeline.fit(train_data)
pr = model.transform(test_data)

evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
metric = evaluator.evaluate(pr)
print("F1 metric = %g" % metric)
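The HashingTF stage is what produces the sparse (10000,[…],[…]) feature vectors seen in the table above: each token is hashed to one of numFeatures buckets and occurrences are counted, so no vocabulary needs to be stored. A rough plain-Python sketch of the idea (Spark uses its own hash function, so the actual bucket indices will not match this sketch):

```python
def hashing_tf(words, num_features=10000):
    """Hashing trick: map each token straight to a bucket index and count.

    Colliding tokens simply share a bucket, which is the trade-off that
    lets the transform run without building a vocabulary first.
    """
    vec = {}
    for w in words:
        idx = hash(w) % num_features
        vec[idx] = vec.get(idx, 0) + 1
    return vec

# Example: three tokens, two of them identical, land in at most two buckets.
tf = hashing_tf(["шина", "nordman", "шина"])
```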
I think the input CSV file should have only 2 columns: category and text. You can find sample data for this all over the Internet...
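To make that input format concrete, here is a tiny sketch of why the gist parses lines with split(',', 1) rather than a full CSV parser: only the first comma separates category from text, so commas inside the product description survive. The sample rows below are invented:

```python
# Hypothetical two-column input lines in the category,text format the gist expects.
lines = [
    "833,Cover for Nokia 5230, black",
    "9,Tire Nordman RS 185/65 R14",
]

# maxsplit=1: only the first comma delimits the columns, so the comma
# inside "Cover for Nokia 5230, black" stays part of the text field.
rows = [line.split(',', 1) for line in lines]
```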