Created May 26, 2013 08:03
-
-
Save thomasjungblut/5652037 to your computer and use it in GitHub Desktop.
Math Library Benchmark: GPU vs. JBLAS vs. pure Java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package de.jungblut.benchmark; | |
import java.util.Random; | |
import com.google.caliper.Param; | |
import com.google.caliper.Runner; | |
import com.google.caliper.SimpleBenchmark; | |
import de.jungblut.math.DoubleMatrix; | |
import de.jungblut.math.cuda.JCUDAMatrixUtils; | |
import de.jungblut.math.dense.DenseDoubleMatrix; | |
public class MathLibBenchmark extends SimpleBenchmark { | |
@Param({ "10", "20", "40", "60", "80", "100", "500", "1000", "2000" }) | |
private int n; | |
@Param | |
CalcType type; | |
private DenseDoubleMatrix mat; | |
private DenseDoubleMatrix mat2; | |
public enum CalcType { | |
GPU, JBLAS, TJ_MATH | |
}; | |
@Override | |
protected void setUp() throws Exception { | |
mat = new DenseDoubleMatrix(n, n, new Random()); | |
mat2 = new DenseDoubleMatrix(n, n, new Random()); | |
} | |
public void timeCalculate(int reps) { | |
for (int rep = 0; rep < reps; rep++) { | |
int sum = 0; | |
switch (type) { | |
case JBLAS: | |
sum = jblas(sum); | |
break; | |
case TJ_MATH: | |
sum = tjmath(sum); | |
break; | |
case GPU: | |
sum = gpu(sum); | |
break; | |
default: | |
break; | |
} | |
System.out.println(sum); | |
} | |
} | |
private int gpu(int sum) { | |
DenseDoubleMatrix multiply2 = JCUDAMatrixUtils.multiply(mat, mat2); | |
sum += multiply2.getColumnCount(); | |
return sum; | |
} | |
private int tjmath(int sum) { | |
DoubleMatrix multiply = mat.multiply(mat2); | |
sum += multiply.getRowCount(); | |
return sum; | |
} | |
private int jblas(int sum) { | |
org.jblas.DoubleMatrix jblasThis = new org.jblas.DoubleMatrix(mat.toArray()); | |
org.jblas.DoubleMatrix jblasOther = new org.jblas.DoubleMatrix( | |
mat2.toArray()); | |
org.jblas.DoubleMatrix jblasRes = new org.jblas.DoubleMatrix(n, n); | |
jblasThis.mmuli(jblasOther, jblasRes); | |
sum += jblasRes.columns; | |
return sum; | |
} | |
public static void main(String[] args) { | |
Runner.main(MathLibBenchmark.class, args); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment