Created February 23, 2016 18:15
-
-
Save humbhenri/dec6b3a09671d76e0dc3 to your computer and use it in GitHub Desktop.
Matrix chain multiplication using dynamic programming
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package algo; | |
import java.math.BigInteger; | |
import java.util.Arrays; | |
import java.util.Random; | |
public class Matrix { | |
private int rows; | |
private int columns; | |
private BigInteger[][] array; | |
public static Matrix zero(int rows, int columns) { | |
Matrix m = new Matrix(); | |
m.rows = rows; | |
m.columns = columns; | |
m.array = new BigInteger[rows][columns]; | |
return m; | |
} | |
public static Matrix random(int rows, int columns) { | |
Random r = new Random(); | |
Matrix m = Matrix.zero(rows, columns); | |
for (int i = 0; i<m.rows; i++) | |
for (int j=0; j<m.columns; j++) | |
m.array[i][j] = BigInteger.valueOf(r.nextInt(100)); | |
return m; | |
} | |
public Matrix mul(Matrix m) { | |
if (columns != m.rows) | |
throw new IllegalArgumentException("matrix is not compatible"); | |
Matrix res = Matrix.zero(rows, m.columns); | |
for (int i = 0; i<rows; i++) { | |
for (int j=0; j<m.columns; j++) { | |
BigInteger sum = BigInteger.ZERO; | |
for (int k=0; k<columns; k++) { | |
sum = sum.add(array[i][k].multiply(m.array[k][j])); | |
} | |
res.array[i][j] = sum; | |
} | |
} | |
return res; | |
} | |
@Override | |
public String toString() { | |
return String.format("{%d x %d}", rows, columns); | |
} | |
@Override | |
public int hashCode() { | |
final int prime = 31; | |
int result = 1; | |
result = prime * result + Arrays.hashCode(array); | |
result = prime * result + columns; | |
result = prime * result + rows; | |
return result; | |
} | |
@Override | |
public boolean equals(Object obj) { | |
if (this == obj) { | |
return true; | |
} | |
if (obj == null) { | |
return false; | |
} | |
if (!(obj instanceof Matrix)) { | |
return false; | |
} | |
Matrix other = (Matrix) obj; | |
if (!Arrays.deepEquals(array, other.array)) { | |
return false; | |
} | |
if (columns != other.columns) { | |
return false; | |
} | |
if (rows != other.rows) { | |
return false; | |
} | |
return true; | |
} | |
private static Matrix timedp(Matrix[] chain) { | |
long before = System.currentTimeMillis(); | |
Matrix m = muldp(chain); | |
long after = System.currentTimeMillis(); | |
System.out.format("Time: %s ms\n", after-before); | |
return m; | |
} | |
private static Matrix muldp(Matrix[] chain) { | |
int[][] s = split(chain); | |
return muldp(chain, 0, chain.length-1, s); | |
} | |
private static Matrix muldp(Matrix[] chain, int i, int j, int[][] s) { | |
if (i == j) | |
return chain[i]; | |
int k = s[i][j]; | |
Matrix x = muldp(chain, i, k, s); | |
Matrix y = muldp(chain, k+1, j, s); | |
return x.mul(y); | |
} | |
private static int[][] split(Matrix[] chain) { | |
//Matrix Ai has dimension p[i-1] x p[i] for i = 1..n | |
int p[] = new int[chain.length+1]; | |
{ | |
int i = 0; | |
for (Matrix m : chain) { | |
p[i] = m.rows; | |
p[i+1] = m.columns; | |
i++; | |
} | |
} | |
int n = p.length -1; | |
// aux table that stores k, where k is the optimum split | |
int s[][] = new int[n][n]; | |
// m[i,j] = Minimum number of scalar multiplications (i.e., cost) | |
// needed to compute the matrix A[i]A[i+1]...A[j] = A[i..j] | |
// cost is zero when multiplying one matrix | |
int m[][] = new int[n][n]; | |
for (int ii = 1; ii < n; ii++) { | |
for (int i = 0; i < n - ii; i++) { | |
int j = i + ii; | |
m[i][j] = Integer.MAX_VALUE; | |
for (int k = i; k < j; k++) { | |
int q = m[i][k] + m[k+1][j] + p[i]*p[k+1]*p[j+1]; | |
if (q < m[i][j]) { | |
m[i][j] = q; | |
s[i][j] = k; | |
} | |
} | |
} | |
} | |
return s; | |
} | |
private static Matrix[] mkChain(int n) { | |
Random r = new Random(); | |
int p[] = new int[n+1]; | |
for (int i=0; i<p.length; i++) p[i] = 10 + r.nextInt(20); | |
Matrix chain[] = new Matrix[n]; | |
for (int i=0; i<n; i++) chain[i] = Matrix.random(p[i], p[i+1]); | |
return chain; | |
} | |
private static Matrix time(Matrix[] chain) { | |
long before = System.currentTimeMillis(); | |
Matrix m = mul(chain); | |
long after = System.currentTimeMillis(); | |
System.out.format("Time: %s ms\n", after-before); | |
return m; | |
} | |
private static Matrix mul(Matrix[] chain) { | |
Matrix res = chain[0]; | |
for (int i=1; i<chain.length; i++) { | |
res = res.mul(chain[i]); | |
} | |
return res; | |
} | |
public static void main(String[] args) { | |
Matrix[] chain = mkChain(1000); | |
Matrix m = time(chain); | |
Matrix n = timedp(chain); | |
assert m.equals(n); | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.