Skip to content

Instantly share code, notes, and snippets.

@lacic
Created November 21, 2018 15:00
Show Gist options
  • Select an option

  • Save lacic/14aaa02e58cf11ff71142746b811c858 to your computer and use it in GitHub Desktop.

Select an option

Save lacic/14aaa02e58cf11ff71142746b811c858 to your computer and use it in GitHub Desktop.
package service.deeplearning;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.documentiterator.BasicLabelAwareIterator;
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
/**
 * Reproduction of a ParagraphVectors save/reload discrepancy.
 *
 * <p>Each test trains a tiny model, writes it to disk with
 * {@link WordVectorSerializer#writeParagraphVectors}, reads it back, and
 * compares vectors inferred by the original and the reloaded model. The
 * original author reports that the comparison holds without negative sampling
 * and fails once negative sampling is enabled.
 */
public class SaveBug {

    /**
     * Iteration count applied both to the trained model and to the reloaded
     * one, so the two models are configured alike when their inferred vectors
     * are compared.
     */
    private static final int ITERATIONS = 15;

    /** Directory the serialized model is written to (relative to the working directory). */
    private static final String MODEL_DIR = "./models";

    /** File name of the serialized model inside {@link #MODEL_DIR}. */
    private static final String MODEL_NAME = "example.pv";

    /** Sentence whose inferred vector is compared before and after the round trip. */
    private static final String TEST_SENTENCE = "test common1 common2 common3";

    /**
     * Creates the tokenizer configuration shared by the trained and the
     * reloaded model.
     *
     * @return a {@link DefaultTokenizerFactory} using a {@link CommonPreprocessor}
     */
    private static TokenizerFactory newTokenizerFactory() {
        TokenizerFactory factory = new DefaultTokenizerFactory();
        factory.setTokenPreProcessor(new CommonPreprocessor());
        return factory;
    }

    /**
     * Trains a ParagraphVectors model on a three-sentence corpus.
     *
     * @param negativeSamples value passed to {@code negativeSample(...)}; when
     *     it is greater than zero, hierarchic softmax is also switched off.
     *     When it is zero or negative the builder defaults are left untouched.
     * @return the fitted model
     */
    private ParagraphVectors train(int negativeSamples) {
        // parameters
        int windowSize = 20;
        double learningRate = 0.025;
        int dimensions = 100;
        int epochs = 1;
        int minWordFrequency = 1;

        // data (plain list: double-brace initialisation would create an
        // anonymous subclass that holds a reference to this test instance)
        List<String> data = new ArrayList<>();
        data.add("common1 unused1 common2 common3");
        data.add("common1 common3 double double");
        data.add("common2 another1");

        // build model
        CollectionSentenceIterator sentenceIterator = new CollectionSentenceIterator(data);
        BasicLabelAwareIterator basicLabelAwareIterator =
                new BasicLabelAwareIterator.Builder(sentenceIterator).build();
        AbstractCache<VocabWord> cache = new AbstractCache<>();
        TokenizerFactory tokenizerFactory = newTokenizerFactory();

        ParagraphVectors.Builder vecBuilder = new ParagraphVectors.Builder()
                .minWordFrequency(minWordFrequency)
                .iterations(ITERATIONS)
                .epochs(epochs)
                .layerSize(dimensions)
                .learningRate(learningRate)
                .labelsSource(basicLabelAwareIterator.getLabelsSource())
                .windowSize(windowSize)
                .iterate(basicLabelAwareIterator)
                .trainWordVectors(true)
                .vocabCache(cache)
                .tokenizerFactory(tokenizerFactory)
                .sampling(0);

        // Critical path: this is the only configuration difference between the
        // passing and the failing test cases.
        if (negativeSamples > 0) {
            vecBuilder
                    .negativeSample(negativeSamples)
                    .useHierarchicSoftmax(false);
        }

        // train
        ParagraphVectors vec = vecBuilder.build();
        vec.fit();
        return vec;
    }

    /**
     * Writes the model to {@code ./models/example.pv} and reads it back.
     *
     * <p>The reloaded model gets a fresh tokenizer factory and the same
     * iteration count as the trained model before it is returned.
     *
     * @param vec the model to serialize
     * @return the model read back from disk
     * @throws IOException if the model directory cannot be created, or if
     *     writing or reading the model file fails
     */
    private ParagraphVectors storeAndReload(ParagraphVectors vec) throws IOException {
        // I/O
        File modelDir = new File(MODEL_DIR);
        // Make sure the target directory exists so that writing the model does
        // not fail on a fresh checkout where ./models has never been created.
        if (!modelDir.isDirectory() && !modelDir.mkdirs()) {
            throw new IOException("Could not create model directory: " + modelDir.getAbsolutePath());
        }
        File storedModel = new File(modelDir, MODEL_NAME);
        if (storedModel.exists()) {
            // Best effort: the result is deliberately not checked, the file is
            // written again immediately below.
            storedModel.delete();
        }

        // store model
        WordVectorSerializer.writeParagraphVectors(vec, storedModel);

        // load model
        ParagraphVectors vecNew = WordVectorSerializer.readParagraphVectors(storedModel);
        vecNew.setTokenizerFactory(newTokenizerFactory());
        vecNew.getConfiguration().setIterations(ITERATIONS);
        return vecNew;
    }

    /**
     * Asserts that the vector inferred for {@link #TEST_SENTENCE} is equal
     * before and after the model is stored and reloaded. Prints the vector
     * inferred before storing and the cosine similarity of the two vectors.
     *
     * @param negativeSamples forwarded to {@link #train(int)}
     * @throws IOException if the store/reload round trip fails
     */
    private void testStoringEquals(int negativeSamples) throws IOException {
        ParagraphVectors vec = train(negativeSamples);

        // infer
        INDArray vectorBeforeStoring = vec.inferVector(TEST_SENTENCE);
        System.out.println(vectorBeforeStoring);

        // reload
        ParagraphVectors vecNew = storeAndReload(vec);

        // test vectors are the same
        INDArray vectorAfterStoring = vecNew.inferVector(TEST_SENTENCE);
        System.out.println(Transforms.cosineSim(vectorBeforeStoring, vectorAfterStoring));
        assertEquals(vectorBeforeStoring, vectorAfterStoring);
    }

    /**
     * Asserts that the reloaded model infers different vectors for two
     * different sentences, i.e. that it does not return the same vector for
     * every input.
     *
     * @param negativeSamples forwarded to {@link #train(int)}
     * @throws IOException if the store/reload round trip fails
     */
    private void testStoringDifferent(int negativeSamples) throws IOException {
        ParagraphVectors vec = train(negativeSamples);

        // reload from file
        ParagraphVectors vecNew = storeAndReload(vec);

        // just to ensure not all the vectors are the same
        INDArray vectorAfterStoring = vecNew.inferVector(TEST_SENTENCE);
        String anotherString = "test another1";
        INDArray anotherVector = vecNew.inferVector(anotherString);
        assertNotEquals(vectorAfterStoring, anotherVector);
    }

    // This test case works as expected when no negative sampling is used.
    @Test
    public void withoutNegativeSampling() throws IOException {
        testStoringEquals(0);
        testStoringDifferent(0);
    }

    // The test cases below are reported as broken by the original author.

    @Test // this test is off by just a bit
    public void withOneNegativeSampleEquals() throws IOException {
        testStoringEquals(1);
    }

    @Test
    public void withOneNegativeSampleDifferent() throws IOException {
        testStoringDifferent(1);
    }

    @Test // this test is off by quite a bit
    public void withManyNegativeSamplesEquals() throws IOException {
        testStoringEquals(20);
    }

    @Test
    public void withManyNegativeSamplesDifferent() throws IOException {
        testStoringDifferent(20);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment