@ErikTromp
Last active October 27, 2020 17:33
doc2vec DL4j
import java.io.File

import scala.collection.JavaConverters._
import scala.util.Try

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors
import org.deeplearning4j.text.documentiterator.{LabelledDocument, SimpleLabelAwareIterator}
import org.deeplearning4j.text.tokenization.tokenizerfactory.{DefaultTokenizerFactory, TokenizerFactory}

def train(dataset: List[String], labels: Option[List[String]]) = {
  val tokenizer: TokenizerFactory = new DefaultTokenizerFactory()
  val labelsUsed = collection.mutable.ListBuffer.empty[String]
  // Create the labelled documents, with a unique label for each document
  val docs = dataset.zipWithIndex.map { case (content, index) =>
    val label = labels match {
      case Some(lbls) => lbls(index)
      case None => "SENT_" + index
    }
    labelsUsed += label
    val doc = new LabelledDocument
    doc.setContent(content)
    doc.addLabel(label)
    doc
  }.asJava
  // Build an iterator over the dataset
  val iterator = new SimpleLabelAwareIterator(docs)
  // Try to load existing word vectors; fall back to None if loading fails
  val word2vec: Option[WordVectors] = Try(
    WordVectorSerializer.loadGoogleModelNonNormalized(new File(w2vModelName), false, false)
  ).toOption
  val builder = new ParagraphVectors.Builder()
    .minWordFrequency(minWordFrequency)
    .iterations(iterations)
    .epochs(epochs)
    .layerSize(layerSize)
    .learningRate(learningRate)
    .windowSize(windowSize)
    .batchSize(batchSize)
    .iterate(iterator)
    .trainWordVectors(trainWordVectors)
    .sampling(sampling)
    .tokenizerFactory(tokenizer)
  // Reuse the pre-trained word vectors only when they were loaded successfully
  val parVecs = (word2vec match {
    case Some(w) => builder.useExistingWordVectors(w)
    case None => builder
  }).build()
  parVecs.fit() // This is where it goes wrong
  parVecs
}
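A minimal usage sketch of the method above. It assumes train returns the fitted ParagraphVectors instance (the gist does not show a return value), that the DL4J dependencies are on the classpath, and uses DL4J's inferVector and nearestLabels calls; the corpus here is illustrative only.

val corpus = List(
  "the cat sat on the mat",
  "dogs chase cats around the garden",
  "stock markets fell sharply today"
)
// With labels = None each document gets an auto-label SENT_0, SENT_1, ...
val model = train(corpus, None)

// Infer an embedding for unseen text
val vec = model.inferVector("a dog chased the cat")

// Retrieve the labels of the two most similar training documents
val closest = model.nearestLabels("a dog chased the cat", 2)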