Created March 6, 2017 16:44
-
-
Save harperjiang/847a1f7f02fd219553761ef06154c349 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public static void main(String[] args) { | |
fast(); | |
slow(); | |
} | |
static void fast() { | |
int hiddenDim = 200; | |
int numChar = 100; | |
int length = 500; | |
int batchSize = 50; | |
INDArray c2v = Nd4j.zeros(numChar, hiddenDim); | |
INDArray h0 = Nd4j.zeros(batchSize, hiddenDim); | |
INDArray c0 = Nd4j.zeros(batchSize, hiddenDim); | |
INDArray fwdmap = Nd4j.zeros(batchSize, numChar); | |
INDArray embed = fwdmap.mmul(c2v); | |
List<INDArray> embeds = new ArrayList<>(); | |
List<INDArray> h0s = new ArrayList<>(); | |
for (int x = 0; x < 1000; x++) { | |
embeds.add(Nd4j.createUninitialized(embed.shape())); | |
h0s.add(Nd4j.createUninitialized(h0.shape())); | |
} | |
long sum = 0; | |
for (int x = 0; x < embeds.size(); x++) { | |
long time1 = System.nanoTime(); | |
INDArray concat = Nd4j.concat(1, embeds.get(x), h0s.get(x)); | |
long time2 = System.nanoTime(); | |
sum += time2 - time1; | |
} | |
System.out.println(sum / embeds.size()); | |
} | |
static void slow() { | |
int hiddenDim = 200; | |
int numChar = 100; | |
int length = 500; | |
int batchSize = 50; | |
INDArray c2v = Nd4j.zeros(numChar, hiddenDim); | |
INDArray h0 = Nd4j.zeros(batchSize, hiddenDim); | |
INDArray c0 = Nd4j.zeros(batchSize, hiddenDim); | |
INDArray fwdmap = Nd4j.zeros(batchSize, numChar); | |
INDArray embed = fwdmap.mmul(c2v); | |
List<INDArray> embeds = new ArrayList<>(); | |
List<INDArray> h0s = new ArrayList<>(); | |
for (int x = 0; x < 1000; x++) { | |
embeds.add(Nd4j.createUninitialized(embed.shape())); | |
h0s.add(Nd4j.createUninitialized(h0.shape())); | |
} | |
long sum = 0; | |
for (int x = 0; x < embeds.size(); x++) { | |
embed = fwdmap.mmul(c2v); | |
long time1 = System.nanoTime(); | |
INDArray concat = Nd4j.concat(1, embeds.get(x), h0s.get(x)); | |
long time2 = System.nanoTime(); | |
sum += time2 - time1; | |
} | |
System.out.println(sum / embeds.size()); | |
} | |
static INDArray xavier(int[] shape) { | |
int n = 1; | |
for (int i = 0; i < shape.length - 1; i++) | |
n *= shape[i]; | |
double sd = Math.sqrt(3d / n); | |
return new UniformDistribution(-sd, sd).sample(shape); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.