@mitallast
Last active October 16, 2019 11:50
Example Naive Bayes Classifier with Apache Spark Pipeline
+--------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
|category| text|label| words| features| rawPrediction| probability|prediction|
+--------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
| 3001|Плойки и наборы V...| 24.0|[плойки, и, набор...|(10000,[326,796,1...|[-174.67716870697...|[6.63481663197049...| 24.0|
| 833|"Чехол-обложка дл...| 1.0|["чехол-обложка, ...|(10000,[514,986,1...|[-379.37151502387...|[5.32678001676623...| 1.0|
| 833|"Чехол-обложка дл...| 1.0|["чехол-обложка, ...|(10000,[514,986,1...|[-379.84825219376...|[2.15785456821554...| 1.0|
| 833|"Чехол-обложка дл...| 1.0|["чехол-обложка, ...|(10000,[290,514,9...|[-395.42735009477...|[6.44323423370500...| 1.0|
| 833|"Чехол-обложка дл...| 1.0|["чехол-обложка, ...|(10000,[290,514,9...|[-396.10251348944...|[6.31147674177529...| 1.0|
| 343|"HP SD SDHC 32GB ...| 4.0|["hp, sd, sdhc, 3...|(10000,[682,728,9...|[-257.91503332110...|[1.39061762573886...| 4.0|
| 833|Накладка на задню...| 1.0|[накладка, на, за...|(10000,[52,262,43...|[-312.52529589453...|[1.36384892167512...| 1.0|
| 9|Шина Nordman RS 1...| 8.0|[шина, nordman, r...|(10000,[1124,1223...|[-70.874699878463...|[8.66152284697728...| 8.0|
| 9|Шина Cordiant Pol...| 8.0|[шина, cordiant, ...|(10000,[50,1125,1...|[-88.024541535325...|[4.88331956423802...| 8.0|
| 9|Шина Amtel NordMa...| 8.0|[шина, amtel, nor...|(10000,[1125,1392...|[-81.198262280603...|[3.17695885661577...| 8.0|
| 9|покрышка Пирелли ...| 8.0|[покрышка, пирелл...|(10000,[104,1786,...|[-71.071831456444...|[4.16786508937252...| 8.0|
| 9|Шина Tigar SIGURA...| 8.0|[шина, tigar, sig...|(10000,[1125,1392...|[-68.318663189674...|[4.99027130148263...| 8.0|
| 9|Шина Kumho 7400 1...| 8.0|[шина, kumho, 740...|(10000,[13,378,11...|[-71.985129861611...|[1.72668174976181...| 8.0|
| 9|Шина Tigar SIGURA...| 8.0|[шина, tigar, sig...|(10000,[1125,3523...|[-69.462434759853...|[2.12061545878723...| 8.0|
| 0|"Стилус для HTC T...| 0.0|["стилус, для, ht...|(10000,[45,575,73...|[-352.85875748691...|[0.99999997337963...| 0.0|
| 833|Корпус для Nokia ...| 1.0|[корпус, для, nok...|(10000,[68,290,57...|[-127.77659993271...|[2.79347683499863...| 1.0|
| 0|Клавиатура Oklick...| 0.0|[клавиатура, okli...|(10000,[770,837,1...|[-235.09281306068...|[1.82375120886902...| 3.0|
| 343|"Флешка SanDisk C...| 4.0|["флешка, sandisk...|(10000,[56,210,30...|[-446.81306611481...|[3.06096855439347...| 4.0|
| 833|"Накладка на задн...| 1.0|["накладка, на, з...|(10000,[53,262,43...|[-271.90628440262...|[1.88877608249557...| 1.0|
| 833|Чехол раскладной ...| 1.0|[чехол, раскладно...|(10000,[171,182,1...|[-214.65871702691...|[4.31375890687199...| 1.0|
+--------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows
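
The prediction table above is presumably the result of calling .show() on the transformed test DataFrame (pr in the script below); a minimal sketch:

pr.show()  # prints the top 20 rows and truncates long cell values by default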
>>>
>>>
>>> print "F1 metric = %g" % metric
F1 metric = 0.842165
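
The script below was apparently run in the PySpark shell, where sc and sqlContext are predefined; it produces the output shown above.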
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import HashingTF, Tokenizer, StringIndexer, IndexToString
from pyspark.ml import Pipeline
from pyspark.sql import Row
textFile = sc.textFile("/Users/mitallast/Sites/spark/sell.csv")
data = textFile.map(lambda line: line.split(',', 1)).map(lambda p: Row(category=p[0], text=p[1]))
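# each CSV line is assumed to contain two columns, category and text, split on the
# first comma only so that product titles containing commas stay intact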
schemaSell = sqlContext.createDataFrame(data)
schemaSell.write.save("/Users/mitallast/Sites/spark/sell.parquet", format="parquet")
schemaSell = sqlContext.read.load("/Users/mitallast/Sites/spark/sell.parquet")
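# the DataFrame is written to Parquet and immediately read back, presumably so that
# later runs can load the Parquet file directly and skip re-parsing the CSV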
train_data, test_data = schemaSell.randomSplit([0.8, 0.2])
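# 80/20 random train/test split; pass a seed, e.g. randomSplit([0.8, 0.2], seed=42),
# if a reproducible split is needed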
# fit the indexer up front so its .labels are available for the IndexToString stage below
categoryIndexer = StringIndexer(inputCol="category", outputCol="label").fit(schemaSell)
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=10000)
nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
categoryConverter = IndexToString(inputCol="prediction", outputCol="predCategory", labels=categoryIndexer.labels)
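# stage order: index category -> tokenize text -> hash tokens into 10000-dim sparse
# term-frequency vectors -> multinomial Naive Bayes -> map predictions back to category strings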
pipeline = Pipeline(stages=[categoryIndexer, tokenizer, hashingTF, nb, categoryConverter])
model = pipeline.fit(train_data)
pr = model.transform(test_data)
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
metric = evaluator.evaluate(pr)
print "F1 metric = %g" % metric
Dirkster99 commented Oct 16, 2019

I think the input CSV file should have only 2 columns: category and text. You can find sample data for this all over the Internet...
