Last active: October 27, 2020 17:33
-
-
Save ErikTromp/87c6ffc11dbecf79174b74b2ea538743 to your computer and use it in GitHub Desktop.
doc2vec with DL4J
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Trains a doc2vec model (DL4J `ParagraphVectors`) over `dataset`, seeded with
 * pre-trained word vectors loaded from `w2vModelName`.
 *
 * Each document receives exactly one label: the caller-supplied label at the
 * same position when `labels` is defined, otherwise a generated `SENT_<index>`.
 *
 * `w2vModelName` and the hyper-parameters (`minWordFrequency`, `iterations`,
 * `epochs`, `layerSize`, `learningRate`, `windowSize`, `batchSize`,
 * `trainWordVectors`, `sampling`) are read from the enclosing scope.
 *
 * @param dataset raw document texts, one entry per document
 * @param labels  optional per-document labels; when defined it must contain
 *                exactly one label per entry of `dataset`
 * @return the fitted `ParagraphVectors` model
 * @throws IllegalArgumentException if `labels` is defined but its size differs
 *                                  from the size of `dataset`
 */
def train(dataset: List[String], labels: Option[List[String]]): ParagraphVectors = {
  // Fail fast with a clear message instead of an IndexOutOfBoundsException
  // (or silently dropped documents) when the label list does not line up.
  labels.foreach { lbls =>
    require(
      lbls.size == dataset.size,
      s"labels has ${lbls.size} entries but dataset has ${dataset.size}; expected one label per document"
    )
  }

  val tokenizer: TokenizerFactory = new DefaultTokenizerFactory()

  // One label per document: caller-supplied, or a unique generated one.
  // Zipping below avoids the O(n) List indexing per document of `lbl(index)`.
  val documentLabels: List[String] =
    labels.getOrElse(dataset.indices.map(i => "SENT_" + i).toList)

  // Create the labelled documents.
  val docs = dataset.zip(documentLabels).map { case (content, label) =>
    val doc = new LabelledDocument
    doc.setContent(content)
    doc.addLabel(label)
    doc
  }.asJava

  // Build the iterator for the dataset.
  val iterator = new SimpleLabelAwareIterator(docs)

  // Load the pre-trained word vectors (non-normalized). Per the DL4J signature
  // the two `false` flags are `binary` and `lineBreaks`.
  val word2vec = WordVectorSerializer.loadGoogleModelNonNormalized(new File(w2vModelName), false, false)

  // NOTE(review): `layerSize` presumably has to match the dimensionality of the
  // loaded vectors; the original author reported fit() failing here — confirm.
  val parVecs = new ParagraphVectors.Builder()
    .minWordFrequency(minWordFrequency)
    .iterations(iterations)
    .epochs(epochs)
    .layerSize(layerSize)
    .learningRate(learningRate)
    .windowSize(windowSize)
    .batchSize(batchSize)
    .iterate(iterator)
    .trainWordVectors(trainWordVectors)
    .sampling(sampling)
    .tokenizerFactory(tokenizer)
    .useExistingWordVectors(word2vec)
    .build()

  parVecs.fit()
  parVecs
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment