paragraph_vector_example.scala
/*
  Ported from: https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/paragraphvectors/ParagraphVectorsClassifierExample.java
  NOTE - you must download the paravec resources data to /opt/data/paravec on your
  SKIL server (a sanity check for this path follows the imports below):
  https://github.com/deeplearning4j/dl4j-examples/tree/master/dl4j-examples/src/main/resources/paravec
*/
import scala.collection.JavaConversions._

import org.datavec.api.util.ClassPathResource
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors
import org.deeplearning4j.models.word2vec.VocabWord
import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator
import org.deeplearning4j.text.documentiterator.LabelledDocument
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j // needed by documentAsVector below
import org.nd4j.linalg.ops.transforms.Transforms

import java.util.concurrent.atomic.AtomicInteger
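// Sanity check (a hedged addition, not in the original port): fail fast if the
// paravec data from the NOTE at the top has not been downloaded to the expected path.
val dataDir = new java.io.File("/opt/data/paravec/labeled")
require(dataDir.isDirectory && dataDir.listFiles().nonEmpty,
  s"Expected labeled paravec data under ${dataDir.getAbsolutePath} - see the NOTE at the top")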
val resource = new ClassPathResource("/opt/data/paravec/labeled")
val iterator = new FileLabelAwareIterator.Builder().addSourceFolder(resource.getFile()).build()

val tokenizerFactory = new DefaultTokenizerFactory()
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor())

// ParagraphVectors training configuration
val paragraphVectors = new ParagraphVectors.Builder()
  .learningRate(0.025)
  .minLearningRate(0.001)
  .batchSize(1000)
  .epochs(20)
  .iterate(iterator)
  .trainWordVectors(true)
  .tokenizerFactory(tokenizerFactory)
  .build()

// Start model training
paragraphVectors.fit()
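// Hedged addition (not in the original example): persist the trained model so it
// can be reloaded later with WordVectorSerializer.readParagraphVectors instead of
// retraining. The output path here is an assumption for illustration.
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer
WordVectorSerializer.writeParagraphVectors(paragraphVectors, "/opt/data/paravec/paravec-model.zip")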
/*
  Builds a document vector as the centroid (mean) of the word vectors of all
  document tokens that are present in the model's vocabulary.
*/
def documentAsVector(lookupTable: InMemoryLookupTable[VocabWord], tokenizerFactory: TokenizerFactory, document: LabelledDocument): INDArray = {
  val vocabCache = lookupTable.getVocab()
  val documentAsTokens = tokenizerFactory.create(document.getContent()).getTokens()

  // First pass: count how many tokens are known to the vocabulary
  val cnt = new AtomicInteger(0)
  for (word <- documentAsTokens) {
    if (vocabCache.containsWord(word)) cnt.incrementAndGet()
  }

  // Second pass: stack the known word vectors into a matrix, one row per word
  val allWords = Nd4j.create(cnt.get(), lookupTable.layerSize())
  cnt.set(0)
  for (word <- documentAsTokens) {
    if (vocabCache.containsWord(word))
      allWords.putRow(cnt.getAndIncrement(), lookupTable.vector(word))
  }

  // The document vector is the column-wise mean of the stacked word vectors
  allWords.mean(0)
}
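// Hedged alternative: instead of averaging word vectors as documentAsVector does,
// ParagraphVectors also exposes inferVector(), which infers an embedding for
// unseen text directly from the trained model, e.g.:
// val inferred: INDArray = paragraphVectors.inferVector("raw document text here")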
/*
  Scores a document vector against every label vector via cosine similarity.
*/
class LabelSeeker(labelsUsed: List[String], lookupTable: InMemoryLookupTable[VocabWord]) {
  if (labelsUsed.isEmpty) throw new IllegalStateException("You can't have 0 labels used for ParagraphVectors")

  def getScores(vector: INDArray): Array[(String, Double)] = {
    labelsUsed.map { label =>
      val vecLabel = lookupTable.vector(label)
      if (vecLabel == null) throw new IllegalStateException("Label '" + label + "' has no known vector!")
      (label, Transforms.cosineSim(vector, vecLabel))
    }.toArray
  }
}
/*
  At this point the model is built, so we can check which categories our
  unlabeled documents fall into. We load the unlabeled documents and score them.
*/
val unClassifiedResource = new ClassPathResource("/opt/data/paravec/unlabeled")
val unClassifiedIterator = new FileLabelAwareIterator.Builder().addSourceFolder(unClassifiedResource.getFile()).build()
/*
  Now we iterate over the unlabeled data and check which label each document
  could be assigned to. Note that in many domains it is normal for one document
  to fall under several labels at once, each with a different "weight".
*/
val labels = iterator.getLabelsSource().getLabels().toList
val seeker = new LabelSeeker(labels, paragraphVectors.getLookupTable().asInstanceOf[InMemoryLookupTable[VocabWord]])

unClassifiedIterator.reset()
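// Hedged debugging aid for the empty-loop issue noted below: if hasNextDocument()
// is immediately false, the unlabeled folder most likely resolved to the wrong path
// or contains no readable files, so list what was actually found there.
val unlabeledDir = unClassifiedResource.getFile()
println(s"Unlabeled source folder: ${unlabeledDir.getAbsolutePath} (exists: ${unlabeledDir.exists()})")
Option(unlabeledDir.listFiles()).getOrElse(Array.empty[java.io.File]).foreach(f => println(s"  found: ${f.getName}"))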
while (unClassifiedIterator.hasNextDocument()) {
  // NOTE: in my runs this loop never executes - unClassifiedIterator reports
  // labels but apparently no documents.
  val document = unClassifiedIterator.nextDocument()
  val documentAsCentroid = documentAsVector(paragraphVectors.getLookupTable().asInstanceOf[InMemoryLookupTable[VocabWord]], tokenizerFactory, document)
  val scores = seeker.getScores(documentAsCentroid)
  /*
    Note: document.getLabel() is used here only to show which document we are
    looking at, as a stand-in for printing the whole document name. The labels
    on these documents act like titles, just to visualize that the
    classification was done properly.
  */
  println("Document '" + document.getLabel() + "' falls into the following categories:")
  for (score <- scores) {
    println("  " + score._1 + ": " + score._2)
  }
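  // Hedged addition (not in the original port): also report the single
  // best-scoring label, since raw cosine similarities are hard to eyeball.
  val (bestLabel, bestScore) = scores.maxBy(_._2)
  println(s"  => best match: $bestLabel ($bestScore)")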
}