Created
November 21, 2018 15:00
-
-
Save lacic/14aaa02e58cf11ff71142746b811c858 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package service.deeplearning; | |
| import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; | |
| import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; | |
| import org.deeplearning4j.models.word2vec.VocabWord; | |
| import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; | |
| import org.deeplearning4j.text.documentiterator.BasicLabelAwareIterator; | |
| import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator; | |
| import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; | |
| import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | |
| import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | |
| import org.junit.Test; | |
| import org.nd4j.linalg.api.ndarray.INDArray; | |
| import org.nd4j.linalg.ops.transforms.Transforms; | |
| import java.io.File; | |
| import java.io.IOException; | |
| import java.util.ArrayList; | |
| import java.util.List; | |
| import static org.junit.Assert.assertEquals; | |
| import static org.junit.Assert.assertNotEquals; | |
| public class SaveBug { | |
| private ParagraphVectors train(int negativeSamples) { | |
| // parameters | |
| int windowSize = 20; | |
| double learningRate = 0.025; | |
| int dimensions = 100; | |
| int epochs = 1; | |
| int minWordFrequency = 1; | |
| // data | |
| List<String> data = new ArrayList<String>() {{ | |
| add("common1 unused1 common2 common3"); | |
| add("common1 common3 double double"); | |
| add("common2 another1"); | |
| }}; | |
| // build model | |
| CollectionSentenceIterator sentenceIterator = new CollectionSentenceIterator(data); | |
| BasicLabelAwareIterator basicLabelAwareIterator = new BasicLabelAwareIterator.Builder(sentenceIterator).build(); | |
| AbstractCache<VocabWord> cache = new AbstractCache<>(); | |
| TokenizerFactory t = new DefaultTokenizerFactory(); | |
| t.setTokenPreProcessor(new CommonPreprocessor()); | |
| ParagraphVectors.Builder vecBuilder = new ParagraphVectors.Builder() | |
| .minWordFrequency(minWordFrequency) | |
| .iterations(15) | |
| .epochs(epochs) | |
| .layerSize(dimensions) | |
| .learningRate(learningRate) | |
| .labelsSource(basicLabelAwareIterator.getLabelsSource()) | |
| .windowSize(windowSize) | |
| .iterate(basicLabelAwareIterator) | |
| .trainWordVectors(true) | |
| .vocabCache(cache) | |
| .tokenizerFactory(t) | |
| .sampling(0); | |
| // Critical path | |
| if (negativeSamples > 0) { | |
| vecBuilder | |
| .negativeSample(negativeSamples) | |
| .useHierarchicSoftmax(false); | |
| } | |
| // train | |
| ParagraphVectors vec = vecBuilder.build(); | |
| vec.fit(); | |
| return vec; | |
| } | |
| private ParagraphVectors storeAndReload(ParagraphVectors vec) throws IOException { | |
| // I/O | |
| String modelName = "example.pv"; | |
| String modelPath = "./models"; | |
| File storedModel = new File(modelPath, modelName); | |
| if (storedModel.exists()) { | |
| storedModel.delete(); | |
| } | |
| // store model | |
| WordVectorSerializer.writeParagraphVectors(vec, storedModel); | |
| // load model | |
| ParagraphVectors vecNew = WordVectorSerializer.readParagraphVectors(storedModel); | |
| TokenizerFactory t2 = new DefaultTokenizerFactory(); | |
| t2.setTokenPreProcessor(new CommonPreprocessor()); | |
| vecNew.setTokenizerFactory(t2); | |
| vecNew.getConfiguration().setIterations(15); | |
| return vecNew; | |
| } | |
| private void testStoringEquals(int negativeSamples) throws IOException { | |
| ParagraphVectors vec = train(negativeSamples); | |
| // infer | |
| String testString = "test common1 common2 common3"; | |
| INDArray vectorBeforeStoring = vec.inferVector(testString); | |
| System.out.println(vectorBeforeStoring); | |
| // reload | |
| ParagraphVectors vecNew = storeAndReload(vec); | |
| // test vectors are the same | |
| INDArray vectorAfterStoring = vecNew.inferVector(testString); | |
| System.out.println(Transforms.cosineSim(vectorBeforeStoring, vectorAfterStoring)); | |
| assertEquals(vectorBeforeStoring, vectorAfterStoring); | |
| } | |
| private void testStoringDifferent(int negativeSamples) throws IOException { | |
| ParagraphVectors vec = train(negativeSamples); | |
| // reload from file | |
| ParagraphVectors vecNew = storeAndReload(vec); | |
| // just to ensure not all the vectors are the same | |
| String testString = "test common1 common2 common3"; | |
| INDArray vectorAfterStoring = vecNew.inferVector(testString); | |
| String anotherString = "test another1"; | |
| INDArray anotherVector = vecNew.inferVector(anotherString); | |
| assertNotEquals(vectorAfterStoring, anotherVector); | |
| } | |
| // This test case is working as expecting, when no negative sampling | |
| @Test | |
| public void withoutNegativeSampling() throws IOException { | |
| testStoringEquals(0); | |
| testStoringDifferent(0); | |
| } | |
| // The below test cases are broken | |
| @Test // this test is off by just a bit | |
| public void withOneNegativeSampleEquals() throws IOException { | |
| testStoringEquals(1); | |
| } | |
| @Test | |
| public void withOneNegativeSampleDifferent() throws IOException { | |
| testStoringDifferent(1); | |
| } | |
| @Test // this test is off by quite a bit | |
| public void withManyNegativeSamplesEquals() throws IOException { | |
| testStoringEquals(20); | |
| } | |
| @Test | |
| public void withManyNegativeSamplesDifferent() throws IOException { | |
| testStoringDifferent(20); | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment