Created
February 4, 2016 14:00
-
-
Save Seppo420/d68e3f26ec62818ea346 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package testicle;/** | |
* Created by jp on 04/02/16. | |
*/ | |
import lombok.EqualsAndHashCode; | |
import lombok.RequiredArgsConstructor; | |
import lombok.ToString; | |
import lombok.extern.slf4j.Slf4j; | |
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; | |
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; | |
import org.nd4j.linalg.api.ndarray.BaseNDArray; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import java.io.File; | |
import java.io.IOException; | |
import java.util.Collection; | |
@Slf4j | |
@RequiredArgsConstructor | |
@ToString | |
@EqualsAndHashCode | |
public class HelloWord2Vec { | |
private WordVectors wordVectors; | |
public static void main(String... args) throws Exception { | |
new HelloWord2Vec().run(); | |
} | |
private void run() throws Exception { | |
File f = new File("GoogleNews-vectors-negative300.bin"); | |
System.out.println("loading vectors... "+f.getAbsolutePath()); | |
wordVectors = WordVectorSerializer.loadGoogleModel(f, true); | |
System.out.println("vectors loaded..."); | |
/* | |
printSimilarity("cat","dog"); | |
printSimilarity("pig","dog"); | |
printSimilarity("man","woman"); | |
printSimilarity("man","dog"); | |
printNearest("putin"); | |
printNearest("finland"); | |
printNearest("penis"); | |
printNearest("cat"); | |
printNearest("marijuana"); | |
printNearest("fuck"); | |
printNearest("shit"); | |
*/ | |
printIsTo("hat","head","glove"); | |
printIsTo("man","woman","king"); | |
printIsTo("man","woman","uncle"); | |
printIsTo("cat","dog","man"); | |
printIsTo("putin","russia","obama"); | |
System.out.println("vectors loaded..."); | |
log.debug("wordVectors:{}",wordVectors); | |
} | |
private void printIsTo(String a, String b, String c) { | |
double[] av = wordVectors.getWordVector(a); | |
double[] bv = wordVectors.getWordVector(b); | |
double[] cv = wordVectors.getWordVector(c); | |
double [] ret = new double[av.length]; | |
for(int i=0;i<ret.length;i++) | |
ret[i]=cv[i]+(bv[i]-av[i]);//cv - (av-bv) = cv -av + bv = cv+(bv-av) | |
Collection<String> col = wordVectors.wordsNearest(Nd4j.create(ret),5); | |
System.out.println(a+ " is to "+b +" as "+c+ " is to ..."); | |
col.stream().filter(s->equalsNone(s,a,b,c)).forEach(System.out::println); | |
} | |
private boolean equalsNone(String s,String a, String b, String c) { | |
return !( | |
s.equalsIgnoreCase(a)|| | |
s.equalsIgnoreCase(b)|| | |
s.equalsIgnoreCase(c) | |
); | |
} | |
private void printNearest(String s) { | |
System.out.println("nearest words to: "+s); | |
wordVectors.wordsNearest(s,10).forEach(System.out::println); | |
} | |
private void printSimilarity(String a, String b) { | |
System.out.println(a+" . "+b+" = "+wordVectors.similarity(a,b)); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment