Skip to content

Instantly share code, notes, and snippets.

@ghostflare76
Last active May 3, 2021 07:48
Show Gist options
  • Save ghostflare76/090947b277ba1e8de5274a612f15980f to your computer and use it in GitHub Desktop.
Save ghostflare76/090947b277ba1e8de5274a612f15980f to your computer and use it in GitHub Desktop.
BBC News article classification (bbc news 기사분류)
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier, NaiveBayes}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, HashingTF, VectorIndexer, StopWordsRemover, NGram, Word2Vec, CountVectorizer, IDF}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.Encoders
import com.johnsnowlabs.nlp.pretrained.PretrainedPipeline
import com.johnsnowlabs.nlp.SparkNLP
import com.johnsnowlabs.nlp.annotator._
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.nlp.annotators
// Report which Spark NLP build is on the classpath. A bare `SparkNLP.version()` is only shown when
// an interactive REPL/notebook echoes the expression's value; printing it makes it visible in any run mode.
println(s"Spark NLP version: ${SparkNLP.version()}")
import java.net.URL
import org.apache.spark.SparkFiles
// Download the BBC News CSV (header row; the pipeline below reads its "category" and "text" columns)
// and split it 80/20 into train/test with a fixed seed so the split is reproducible.
val urlfile = "https://storage.googleapis.com/dataset-uploader/bbc/bbc-text.csv"
// SparkContext.addFile stores the download under the URL's last path segment, which is the name
// SparkFiles.get expects; derive it from `urlfile` instead of repeating it as a second literal.
val urlFileName = urlfile.substring(urlfile.lastIndexOf('/') + 1)
spark.sparkContext.addFile(urlfile)
// NOTE(review): SparkFiles.get resolves the path on the driver. Reading it via "file://" works in local
// mode; on a multi-node cluster the same absolute path must also exist on every executor — confirm.
val df = spark.read
  .option("inferSchema", true)
  .option("header", true)
  .csv(s"file://${SparkFiles.get(urlFileName)}")
val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed = 123L)
// Encode the string "category" column as the numeric "label" column that the classifier and the
// evaluator use. handleInvalid = "keep" maps unseen/invalid categories to an extra index instead of failing.
val categoryIndexer: StringIndexer = new StringIndexer()
  .setHandleInvalid("keep") // "keep", "error" or "skip"
  .setOutputCol("label")
  .setInputCol("category")
// Spark NLP entry point: wrap the raw "text" column into "document" annotations,
// then split each document into "token" annotations.
val documentAssembler: DocumentAssembler = new DocumentAssembler()
  .setOutputCol("document")
  .setInputCol("text")
val token: Tokenizer = new Tokenizer()
  .setOutputCol("token")
  .setInputCols("document")
// Clean and lowercase the tokens into the "normalized" column. Slang matching is set to
// case-insensitive, although no slang dictionary is configured here.
// To override the cleanup regex, pass a Scala Array (not a Python-style list), e.g.:
//   .setCleanupPatterns(Array("""[^\w\d\s]"""))
val normalizer: Normalizer = new Normalizer()
  .setSlangMatchCase(false)
  .setLowercase(true)
  .setOutputCol("normalized")
  .setInputCols("token")
// Disabled alternative to the lemmatizer below: stem the normalized tokens instead.
// val stemmer = new Stemmer()
// .setInputCols("normalized")
// .setOutputCol("stem")
// .setLanguage("English")
// .setLazyAnnotator(false)
// Reduce each normalized token to its lemma with Spark NLP's default pretrained lemmatizer
// (fetched through the pretrained-model downloader, so network access or a local cache is needed).
val lemmatizer: LemmatizerModel = LemmatizerModel.pretrained()
  .setOutputCol("lemma")
  .setInputCols("normalized")
// Convert the "lemma" annotations into a plain array-of-strings column, "token_features", for the
// Spark ML stages that follow. cleanAnnotations(false) keeps the intermediate annotation columns
// (such as "normalized" and "lemma") so they can still be inspected after transform.
val finisher: Finisher = new Finisher()
  .setCleanAnnotations(false)
  .setOutputAsArray(true)
  .setOutputCols("token_features")
  .setInputCols("lemma")
// Remove stop words (Spark ML's default English list) from the lemmatized tokens.
val remover: StopWordsRemover = new StopWordsRemover()
  .setOutputCol("filtered")
  .setInputCol("token_features")
// Term frequencies: learn a vocabulary from "filtered" and emit count vectors in "tf".
// Optional vocabulary limits, currently left at their defaults:
// .setVocabSize(3)
// .setMinDF(2)
val tf: CountVectorizer = new CountVectorizer()
  .setOutputCol("tf")
  .setInputCol("filtered")
// Re-weight the term counts by inverse document frequency; the result goes to the "idf" column.
// Optional, currently left at its default:
// .setMinDocFreq(0)
val idf: IDF = new IDF()
  .setOutputCol("idf")
  .setInputCol("tf")
// Disabled alternative to CountVectorizer + IDF: feature hashing.
// NOTE(review): `tokenizer` below does not exist in this script (the tokenizer val is `token`,
// and HashingTF needs an array-of-strings column); fix the input before re-enabling.
// val hashingTF = new HashingTF()
// .setNumFeatures(1000)
// .setInputCol(tokenizer.getOutputCol)
// .setOutputCol("features")
// Naive Bayes classifier trained on the TF-IDF vectors ("idf") against the indexed "label" column.
// Spark's NaiveBayes defaults to the multinomial model, which requires non-negative features.
val nb: NaiveBayes = new NaiveBayes()
  .setLabelCol("label")
  .setFeaturesCol("idf")
// Stage order: label indexing, Spark NLP preprocessing, stop-word removal, TF-IDF, then Naive Bayes.
val stages: Array[org.apache.spark.ml.PipelineStage] = Array(categoryIndexer, documentAssembler, token, normalizer, lemmatizer, finisher, remover, tf, idf, nb)
val pipeline = new Pipeline().setStages(stages)
// Fit every stage on the training split, then score the held-out test split.
val pipelineModel: PipelineModel = pipeline.fit(train)
val prediction = pipelineModel.transform(test)
// Spot-check the preprocessing: show the normalized tokens first, then their lemmas.
Seq("normalized.result", "lemma.result").foreach(column => z.show(prediction.select(column)))
// Score the test-split predictions by accuracy, comparing the "prediction" column to "label".
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
// Computed once and never reassigned, so bind it with `val` rather than `var`.
val acc = evaluator.evaluate(prediction)
println(s"acc: $acc")
@ghostflare76
Copy link
Author

ghostflare76 commented Apr 7, 2021

com.johnsnowlabs.nlp:spark-nlp_2.11:jar:2.4.5

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment