Last active
May 3, 2021 07:48
-
-
Save ghostflare76/090947b277ba1e8de5274a612f15980f to your computer and use it in GitHub Desktop.
BBC news article classification (bbc news 기사분류)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml.{Pipeline, PipelineModel} | |
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier, NaiveBayes} | |
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator | |
import org.apache.spark.ml.feature.{StringIndexer, HashingTF, VectorIndexer, StopWordsRemover, NGram, Word2Vec, CountVectorizer, IDF} | |
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder} | |
import org.apache.spark.sql.Encoders | |
import com.johnsnowlabs.nlp.pretrained.PretrainedPipeline | |
import com.johnsnowlabs.nlp.SparkNLP | |
import com.johnsnowlabs.nlp.annotator._ | |
import com.johnsnowlabs.nlp.base._ | |
import com.johnsnowlabs.nlp.annotators | |
// Print the Spark NLP library version to confirm the environment is set up.
SparkNLP.version()
import java.net.URL | |
import org.apache.spark.SparkFiles | |
// Download the BBC news dataset through Spark's file-distribution mechanism
// so it is available on every executor, then load it as a DataFrame.
val datasetUrl = "https://storage.googleapis.com/dataset-uploader/bbc/bbc-text.csv"
spark.sparkContext.addFile(datasetUrl)

// CSV has a header row ("category", "text"); let Spark infer column types.
val df = spark.read
  .option("header", true)
  .option("inferSchema", true)
  .csv("file://" + SparkFiles.get("bbc-text.csv"))

// 80/20 train/test split; fixed seed keeps the split reproducible.
val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed = 123L)
// --- Spark NLP text-preprocessing stages ---

// Encode the string "category" column as a numeric "label" for the classifier.
// "keep" routes unseen categories to an extra index instead of failing
// (alternatives: "error", "skip").
val categoryIndexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("label")
  .setHandleInvalid("keep")

// Wrap raw "text" into Spark NLP's document annotation type — the entry point
// of every Spark NLP pipeline.
val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// Split each document into word tokens.
val token = new Tokenizer()
  .setInputCols("document")
  .setOutputCol("token")

// Lower-case the tokens; slang matching stays case-insensitive.
val normalizer = new Normalizer()
  .setInputCols("token")
  .setOutputCol("normalized")
  .setLowercase(true)
  // .setCleanupPatterns(Array("[^\\w\\d\\s]")) // optional: strip punctuation
  .setSlangMatchCase(false)

// A Stemmer could be used in place of the lemmatizer below:
// val stemmer = new Stemmer()
//   .setInputCols("normalized")
//   .setOutputCol("stem")
//   .setLanguage("English")
//   .setLazyAnnotator(false)

// Reduce each normalized token to its dictionary form using the pretrained
// English lemmatizer model (downloaded on first use).
val lemmatizer = LemmatizerModel.pretrained()
  .setInputCols("normalized")
  .setOutputCol("lemma")

// Convert Spark NLP annotations back into a plain array-of-strings column
// ("token_features") so the Spark ML transformers downstream can consume it.
// Keeping the annotation columns lets us inspect them after transform.
val finisher = new Finisher()
  .setInputCols("lemma")
  .setOutputCols("token_features")
  .setOutputAsArray(true)
  .setCleanAnnotations(false)
// --- Spark ML feature-extraction and model stages ---

// Remove common English stop words from the lemmatized tokens.
val remover = new StopWordsRemover()
  .setInputCol("token_features")
  .setOutputCol("filtered")

// Raw term-frequency vectors over the filtered vocabulary.
val tf = new CountVectorizer()
  .setInputCol("filtered")
  .setOutputCol("tf")
  // .setVocabSize(3)
  // .setMinDF(2)

// Re-weight term frequencies by inverse document frequency.
val idf = new IDF()
  .setInputCol("tf")
  .setOutputCol("idf")
  // .setMinDocFreq(0)

// HashingTF is a lighter-weight (but lossy) alternative to CountVectorizer:
// val hashingTF = new HashingTF()
//   .setNumFeatures(1000)
//   .setInputCol(tokenizer.getOutputCol)
//   .setOutputCol("features")

// Multinomial Naive Bayes trained on the TF-IDF features.
val nb = new NaiveBayes()
  .setFeaturesCol("idf")
  .setLabelCol("label")
// Assemble the end-to-end pipeline:
// label indexing -> NLP preprocessing -> TF-IDF features -> Naive Bayes.
val pipeline = new Pipeline()
  .setStages(Array(categoryIndexer, documentAssembler, token, normalizer,
    lemmatizer, finisher, remover, tf, idf, nb))

// Fit every stage on the training split, then score the held-out test split.
val pipelineModel = pipeline.fit(train)
val prediction = pipelineModel.transform(test)

// Inspect intermediate annotation columns (Zeppelin display helper).
z.show(prediction.select("normalized.result"))
z.show(prediction.select("lemma.result"))

// Accuracy of predicted label vs. indexed true label on the test split.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
// val (not var): the score is never reassigned.
val acc = evaluator.evaluate(prediction)
println(s"acc: $acc")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
com.johnsnowlabs.nlp:spark-nlp_2.11:jar:2.4.5