Skip to content

Instantly share code, notes, and snippets.

@Seppo420
Created February 4, 2016 14:00
Show Gist options
  • Save Seppo420/d68e3f26ec62818ea346 to your computer and use it in GitHub Desktop.
Save Seppo420/d68e3f26ec62818ea346 to your computer and use it in GitHub Desktop.
package testicle;/**
* Created by jp on 04/02/16.
*/
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
/**
 * Small command-line demo that loads a pre-trained word vector model via
 * {@code WordVectorSerializer.loadGoogleModel} and prints word analogies of
 * the form "a is to b as c is to ...".
 *
 * <p>The model path may be passed as the first command-line argument; when no
 * argument is given, {@code GoogleNews-vectors-negative300.bin} (resolved
 * against the working directory) is used.
 */
@Slf4j
@RequiredArgsConstructor
@ToString
@EqualsAndHashCode
public class HelloWord2Vec {

    /** Model file used when no path is supplied on the command line. */
    private static final String DEFAULT_MODEL_PATH = "GoogleNews-vectors-negative300.bin";

    /** Maximum number of answers printed for each analogy. */
    private static final int ANALOGY_RESULTS = 5;

    /** Number of query words in an analogy that may show up among the neighbours. */
    private static final int ANALOGY_QUERY_WORDS = 3;

    /** Number of neighbours printed by {@link #printNearest(String)}. */
    private static final int NEAREST_RESULTS = 10;

    /** Loaded model; assigned in {@link #run(String)} before any lookup is made. */
    private WordVectors wordVectors;

    /**
     * Entry point.
     *
     * @param args optional; {@code args[0]} is the path of the model file
     * @throws Exception if loading the model fails
     */
    public static void main(String... args) throws Exception {
        String modelPath = args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;
        new HelloWord2Vec().run(modelPath);
    }

    /**
     * Loads the model and prints a fixed set of example analogies.
     *
     * @param modelPath path of the model file to load
     * @throws Exception if loading the model fails
     */
    private void run(String modelPath) throws Exception {
        File modelFile = new File(modelPath);
        System.out.println("loading vectors... " + modelFile.getAbsolutePath());
        // Second argument is passed as true, as in the original demo
        // (presumably the "binary format" flag - see the DL4J API docs).
        wordVectors = WordVectorSerializer.loadGoogleModel(modelFile, true);
        System.out.println("vectors loaded...");

        /*
         * Further demos, disabled by default:
         * printSimilarity("cat", "dog");
         * printSimilarity("pig", "dog");
         * printSimilarity("man", "woman");
         * printSimilarity("man", "dog");
         * printNearest("putin");
         * printNearest("finland");
         * printNearest("cat");
         * printNearest("marijuana");
         */

        printIsTo("hat", "head", "glove");
        printIsTo("man", "woman", "king");
        printIsTo("man", "woman", "uncle");
        printIsTo("cat", "dog", "man");
        printIsTo("putin", "russia", "obama");

        log.debug("wordVectors:{}", wordVectors);
    }

    /**
     * Prints up to {@link #ANALOGY_RESULTS} words completing the analogy
     * "{@code a} is to {@code b} as {@code c} is to ...".
     *
     * <p>The target vector is {@code c + (b - a)}; its nearest words are
     * printed, excluding the three query words themselves. If any query word
     * has no vector, a message is printed and the analogy is skipped.
     *
     * @param a first word of the known pair
     * @param b second word of the known pair
     * @param c first word of the pair to complete
     */
    private void printIsTo(String a, String b, String c) {
        double[] av = wordVectors.getWordVector(a);
        double[] bv = wordVectors.getWordVector(b);
        double[] cv = wordVectors.getWordVector(c);
        // NOTE(review): some DL4J versions appear to return null for
        // out-of-vocabulary words - guard so one unknown word cannot crash the demo.
        if (av == null || bv == null || cv == null) {
            System.out.println("cannot solve '" + a + " is to " + b + " as " + c
                    + " is to ...': at least one word has no vector");
            return;
        }

        double[] target = new double[av.length];
        for (int i = 0; i < target.length; i++) {
            // cv - (av - bv) = cv - av + bv = cv + (bv - av)
            target[i] = cv[i] + (bv[i] - av[i]);
        }

        // Ask for extra candidates: the query words are filtered out below and
        // would otherwise eat into the number of answers shown.
        Collection<String> candidates =
                wordVectors.wordsNearest(Nd4j.create(target), ANALOGY_RESULTS + ANALOGY_QUERY_WORDS);

        System.out.println(a + " is to " + b + " as " + c + " is to ...");
        candidates.stream()
                .filter(s -> equalsNone(s, a, b, c))
                .limit(ANALOGY_RESULTS)
                .forEach(System.out::println);
    }

    /**
     * Tells whether {@code s} differs from all three given words, ignoring case.
     *
     * @return {@code true} if {@code s} equals none of {@code a}, {@code b}, {@code c}
     */
    private boolean equalsNone(String s, String a, String b, String c) {
        return !(
                s.equalsIgnoreCase(a)
                        || s.equalsIgnoreCase(b)
                        || s.equalsIgnoreCase(c)
        );
    }

    /**
     * Prints the {@link #NEAREST_RESULTS} words returned by
     * {@code wordsNearest} for {@code s}.
     *
     * @param s the word to look up
     */
    private void printNearest(String s) {
        System.out.println("nearest words to: " + s);
        wordVectors.wordsNearest(s, NEAREST_RESULTS).forEach(System.out::println);
    }

    /**
     * Prints the value of {@code wordVectors.similarity(a, b)} for the two words.
     *
     * @param a first word
     * @param b second word
     */
    private void printSimilarity(String a, String b) {
        System.out.println(a + " . " + b + " = " + wordVectors.similarity(a, b));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment