Skip to content

Instantly share code, notes, and snippets.

@angeris
Last active August 31, 2016 17:03
Show Gist options
  • Save angeris/15f5304c69e883063ca6 to your computer and use it in GitHub Desktop.
Save angeris/15f5304c69e883063ca6 to your computer and use it in GitHub Desktop.
A simple, naive Matrix class with QR factorization, back-substitution, and least squares in Java for use in Hackerrank's easier Statistics/Machine Learning problems
/*
* I threw together this set of classes to work on Hackerrank's stuff at a fairly low level. Of
* course, you are provided with several libraries to do this for you, such as Weka, which makes
* life quite a bit easier than setting up your own feature matrix. But for those who want to
* have the simple stuff such as naive QR using Gram-Schmidt, back-substitution, and school-book
* matrix multiplication, this will work. Even for fairly large datasets within Hackerrank
* the performance is fast enough to give a very wide margin of extra time. Enjoy!
*
* Guille (@Guillean)
*/
import java.util.InputMismatchException;
/*
* I used no getters/setters for the matrix and dimensions just to keep the code
* a bit cleaner and make sure it's decently fast (that is, it's not really safe).
* I'm sure there are better ways of doing this, so feel free to fork and edit it as you wish.
*/
/**
 * Simple dense matrix of doubles with a handful of linear-algebra operations:
 * matrix and vector products (optionally using the transpose of this matrix),
 * back-substitution and a school-book Gram-Schmidt QR factorization.
 *
 * <p>The backing array {@link #a} is public and is never re-validated after
 * construction; every method assumes it still has N rows of M entries each.
 *
 * @author guillean
 */
class Matrix {
    /** Backing storage, indexed as {@code a[row][column]}. */
    public double[][] a;

    /** Number of rows (N) and number of columns (M), fixed by the constructors. */
    private int N, M;

    /**
     * Initialize a zeros matrix.
     *
     * @param dim_r Number of rows.
     * @param dim_c Number of columns.
     */
    public Matrix(int dim_r, int dim_c) {
        N = dim_r;
        M = dim_c;
        a = new double[N][M];
    }

    /**
     * Deep-copy constructor.
     *
     * @param b Matrix to copy.
     */
    public Matrix(Matrix b) {
        N = b.N;
        M = b.M;
        a = new double[N][M];
        for (int i = 0; i < N; i++) {
            System.arraycopy(b.a[i], 0, a[i], 0, M);
        }
    }

    /**
     * Initialize a matrix with a copy of the values given by {@code b}. The number of
     * columns is taken from the first row; an array with no rows produces a 0 x 0 matrix.
     * Note that this does not check for ragged/jagged arrays.
     *
     * @param b The array to copy into a matrix.
     */
    public Matrix(double b[][]) {
        N = b.length;
        // Guard the empty case: b[0] does not exist when there are no rows.
        M = (N == 0) ? 0 : b[0].length;
        a = new double[N][M];
        for (int i = 0; i < N; i++) {
            System.arraycopy(b[i], 0, a[i], 0, M);
        }
    }

    /**
     * Builds a unit (identity) matrix.
     *
     * @param i Dimension of the unit matrix.
     * @return Unit matrix (ones on the diagonal, zeros elsewhere) of dimensions i x i.
     */
    public static Matrix Unit(int i) {
        Matrix out = new Matrix(i, i);
        for (int j = 0; j < i; j++) {
            out.a[j][j] = 1;
        }
        return out;
    }

    /**
     * Computes the matrix product {@code this * b}.
     *
     * @param b Matrix to multiply by.
     * @return The product of the two matrices, of dimension N x b.M.
     * @throws IllegalArgumentException if the column count of this matrix differs from
     *         the row count of {@code b}.
     */
    public Matrix mult(Matrix b) {
        if (M != b.N) throw new IllegalArgumentException("Mismatched Dimensions");
        Matrix m = new Matrix(N, b.M);
        // i-k-j order walks rows of b instead of columns. Each output entry is still
        // accumulated over k in increasing order, so the sums are unchanged.
        for (int i = 0; i < N; i++) {
            double[] lhsRow = a[i];
            double[] outRow = m.a[i];
            for (int k = 0; k < M; k++) {
                double scale = lhsRow[k];
                double[] rhsRow = b.a[k];
                for (int j = 0; j < b.M; j++) {
                    outRow[j] += scale * rhsRow[j];
                }
            }
        }
        return m;
    }

    /**
     * Multiply the transpose of this matrix into matrix b.
     *
     * @param b The matrix to multiply by.
     * @return An M x b.M matrix.
     * @throws IllegalArgumentException if the two matrices have different row counts.
     */
    public Matrix multT(Matrix b) {
        if (N != b.N) throw new IllegalArgumentException("Mismatched Dimensions");
        Matrix m = new Matrix(M, b.M);
        // k-i-j order reads one row of each operand at a time. Each output entry is
        // still accumulated over k in increasing order, so the sums are unchanged.
        for (int k = 0; k < N; k++) {
            double[] lhsRow = a[k];
            double[] rhsRow = b.a[k];
            for (int i = 0; i < M; i++) {
                double scale = lhsRow[i];
                double[] outRow = m.a[i];
                for (int j = 0; j < b.M; j++) {
                    outRow[j] += scale * rhsRow[j];
                }
            }
        }
        return m;
    }

    /**
     * Formats the matrix inside square brackets, one row per line, each entry printed
     * with the {@code %.3e} format and entries separated by single spaces.
     *
     * @return Text representation of the matrix.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                sb.append(String.format("%.3e ", a[i][j]));
            }
            // Drop the separator after the last entry before starting a new line.
            trimTrailingWhitespace(sb);
            sb.append("\n ");
        }
        trimTrailingWhitespace(sb);
        return sb.append(']').toString();
    }

    /**
     * Removes trailing characters with code {@code <= ' '} from the builder, i.e. the
     * same characters {@link String#trim()} strips from the end of a string.
     */
    private static void trimTrailingWhitespace(StringBuilder sb) {
        int end = sb.length();
        while (end > 0 && sb.charAt(end - 1) <= ' ') {
            end--;
        }
        sb.setLength(end);
    }

    /**
     * Returns the size of the matrix through the given dimension.
     *
     * @param dim The dimension (either 0 for rows or 1 for columns) to measure.
     * @return The size of the matrix through dim, or -1 if dim is out of range.
     */
    public int size(int dim) {
        return dim == 0 ? N : dim == 1 ? M : -1;
    }

    /**
     * Writes the given vector into the specified column of the matrix.
     *
     * @param a The vector whose entries are copied into the column.
     * @param col The column index to write to.
     * @throws IllegalArgumentException if the vector length differs from the row count.
     */
    public void writeCol(Vector a, int col) {
        if (a.N != N) throw new IllegalArgumentException("Column dimensions do not match");
        for (int i = 0; i < N; i++) {
            // The parameter shadows the field, hence the explicit "this".
            this.a[i][col] = a.a[i];
        }
    }

    /**
     * Gives the vector of the product of this matrix and the specified vector.
     *
     * @param b Vector to multiply into.
     * @return The result of the multiplication, a vector of length N.
     * @throws IllegalArgumentException if the vector length differs from the column count.
     */
    public Vector mult(Vector b) {
        if (b.N != M) throw new IllegalArgumentException("Dimensions do not match");
        Vector out = new Vector(N);
        for (int i = 0; i < N; i++) {
            double total = 0;
            for (int j = 0; j < M; j++) {
                total += a[i][j] * b.a[j];
            }
            out.a[i] = total;
        }
        return out;
    }

    /**
     * Gives the vector of the product of the transpose of this matrix and the specified
     * vector.
     *
     * @param b Vector to multiply into.
     * @return The result of the multiplication, a vector of length M.
     * @throws IllegalArgumentException if the vector length differs from the row count.
     */
    public Vector multT(Vector b) {
        if (b.N != N) throw new IllegalArgumentException("Dimensions do not match");
        Vector out = new Vector(M);
        for (int i = 0; i < M; i++) {
            double total = 0;
            for (int j = 0; j < N; j++) {
                total += a[j][i] * b.a[j];
            }
            out.a[i] = total;
        }
        return out;
    }

    /**
     * Back-solve an upper-triangular system with non-zero diagonal. Does not check
     * if the system is upper triangular: entries below the diagonal are never read.
     * A zero on the diagonal is not detected and is simply divided by.
     *
     * @param b Given Vector to back-solve for.
     * @return Solutions vector.
     * @throws IllegalArgumentException if the vector length differs from the row count
     *         or the matrix is not square.
     */
    public Vector backSolve(Vector b) {
        if (b.N != N) throw new IllegalArgumentException("Dimensions do not match");
        if (N != M) throw new IllegalArgumentException("Not a square matrix");
        Vector sol = new Vector(N);
        // Work upwards from the last row; row i only needs the already solved entries j > i.
        for (int i = N - 1; i >= 0; i--) {
            sol.a[i] = b.a[i];
            for (int j = i + 1; j < N; j++) {
                sol.a[i] -= sol.a[j] * a[i][j];
            }
            sol.a[i] /= a[i][i];
        }
        return sol;
    }

    /**
     * Runs QR factorization (school-book Gram-Schmidt) on the matrix and
     * returns a QRFactorization object containing the factored matrices.
     * This matrix is not modified.
     *
     * @return QRFactorization object containing the factored matrices.
     * @throws IllegalArgumentException if a column has an exactly zero remainder after
     *         the projections onto the previous columns are removed.
     */
    public QRFactorization qr() {
        QRFactorization qr = new QRFactorization(N, M);
        double Q[][] = qr.Q.a;
        double R[][] = qr.R.a;
        double q[] = new double[N]; // working copy of the column being orthogonalized
        int i, j, k;
        for (i = 0; i < M; i++) {
            for (j = 0; j < N; j++) {
                q[j] = a[j][i];
            }
            // Remove projections onto the already computed columns of Q. The coefficient
            // is taken against the original column of this matrix, not the running q.
            for (j = i - 1; j >= 0; j--) {
                R[j][i] = dot(a, Q, i, j);
                for (k = 0; k < N; k++) q[k] -= R[j][i] * Q[k][j];
            }
            R[i][i] = Math.sqrt(dot(q, q));
            // Normalize. NOTE(review): only an exactly zero norm is rejected here.
            if (R[i][i] == 0) throw new IllegalArgumentException("Singular matrix");
            _mult(1 / R[i][i], q);
            // Copy the normalized column into Q.
            for (j = 0; j < N; j++) {
                Q[j][i] = q[j];
            }
        }
        return qr;
    }

    /**
     * Returns the dot product of two columns of given double[][] variables. Does not check
     * for jagged/ragged arrays.
     *
     * @param a Array to get column of.
     * @param b Array to get column of.
     * @param col_a Specifies the column of array a to compute dot product.
     * @param col_b Specifies the column of array b to compute dot product.
     * @return The dot product of columns col_a and col_b of arrays a and b, respectively.
     * @throws IllegalArgumentException if the arrays have different row counts.
     */
    private static double dot(double a[][], double b[][], int col_a, int col_b) {
        if (a.length != b.length) throw new IllegalArgumentException("Mismatched Dimensions");
        double total = 0;
        for (int i = 0; i < a.length; i++) {
            total += a[i][col_a] * b[i][col_b];
        }
        return total;
    }

    /**
     * Returns the dot product of arrays of equal length.
     *
     * @param a Array to get dot product of.
     * @param b Array to get dot product of.
     * @return Dot product of arrays a and b.
     * @throws IllegalArgumentException if the arrays have different lengths.
     */
    private static double dot(double a[], double b[]) {
        if (a.length != b.length) throw new IllegalArgumentException("Mismatched Dimensions");
        double total = 0;
        for (int i = 0; i < a.length; i++) {
            total += a[i] * b[i];
        }
        return total;
    }

    /**
     * Returns a new array holding {@code alpha * b}; {@code b} is left untouched.
     * Not called anywhere inside this class.
     *
     * @param alpha Scale factor.
     * @param b Array to scale.
     * @return Newly allocated scaled copy of b.
     */
    private static double[] mult(double alpha, double b[]) {
        double t[] = new double[b.length];
        for (int i = 0; i < b.length; i++) {
            t[i] = b[i] * alpha;
        }
        return t;
    }

    /**
     * Scales {@code b} by {@code alpha} in place.
     *
     * @param alpha Scale factor.
     * @param b Array that is overwritten with its scaled values.
     */
    private static void _mult(double alpha, double b[]) {
        for (int i = 0; i < b.length; i++) {
            b[i] *= alpha;
        }
    }
}
/**
 * Holder for the two factors of a QR factorization, together with a
 * least squares solver built on top of them. A constrained least squares
 * solver and a least-norm solver are possible later additions.
 *
 * @author guillean
 */
class QRFactorization {
    /** The Q factor, allocated as N x M. */
    public Matrix Q;
    /** The R factor, allocated as M x M. */
    public Matrix R;

    /**
     * Allocates zero-filled Q and R factors for an N x M matrix.
     *
     * @param N Rows of matrix to factor.
     * @param M Columns of matrix to factor.
     */
    public QRFactorization(int N, int M) {
        this.Q = new Matrix(N, M);
        this.R = new Matrix(M, M);
    }

    /**
     * Solves the linear least squares problem for the factored feature
     * matrix, i.e. minimizes |Q*R*x-(in)|^2.
     *
     * @param in The vector to solve for.
     * @return The least squares solution to the system.
     */
    public Vector solveLeastSquares(Vector in) {
        // Multiply by the transpose of Q first, then back-substitute through R.
        Vector projected = Q.multT(in);
        Vector solution = R.backSolve(projected);
        return solution;
    }
}
/**
 * Simple dense vector of doubles.
 *
 * <p>The backing array {@link #a} is public and is never re-validated after
 * construction; methods assume it still holds N entries.
 *
 * @author guillean
 */
class Vector {
    /** Backing storage of the vector entries. */
    public double[] a;

    /** Number of entries, fixed by the constructors. */
    int N;

    /**
     * Constructs an N-vector filled with zeros.
     *
     * @param N Size of vector.
     */
    public Vector(int N) {
        a = new double[N];
        this.N = N;
    }

    /**
     * Creates a deep-copy of the array arr[] into a vector.
     *
     * @param arr Array to copy.
     */
    public Vector(double arr[]) {
        N = arr.length;
        a = arr.clone();
    }

    /**
     * Creates a deep-copy of vector vec.
     *
     * @param vec Vector to copy.
     */
    public Vector(Vector vec) {
        N = vec.N;
        a = new double[N];
        System.arraycopy(vec.a, 0, a, 0, N);
    }

    /**
     * Creates a 1 M-vector.
     *
     * @param M Size of Vector.
     * @return Ones Vector of size M.
     */
    public static Vector One(int M) {
        return Vector.Rep(M, 1);
    }

    /**
     * Creates a 0 M-vector.
     *
     * @param M Size of Vector.
     * @return Zeros Vector of size M.
     */
    public static Vector Zero(int M) {
        return new Vector(M);
    }

    /**
     * Creates an M-vector with entries val.
     *
     * @param M Size of Vector.
     * @param val Value to copy into every entry.
     * @return Vector of size M with entries of value val.
     */
    public static Vector Rep(int M, double val) {
        Vector out = new Vector(M);
        // Always fill: skipping "small" values would silently turn tiny non-zero
        // values and NaN into a zero vector.
        for (int i = 0; i < M; i++) {
            out.a[i] = val;
        }
        return out;
    }

    /**
     * Returns the dot product of this with Vector b.
     *
     * @param b The vector to take the dot product with.
     * @return The dot product of this with Vector b.
     * @throws IllegalArgumentException if the two vectors have different sizes.
     */
    public double dot(Vector b) {
        if (N != b.N) throw new IllegalArgumentException("Mismatched dimensions");
        double total = 0;
        for (int i = 0; i < N; i++) {
            total += a[i] * b.a[i];
        }
        return total;
    }

    /**
     * Formats the vector inside square brackets, each entry printed with the
     * {@code %.3e} format and entries separated by single spaces.
     *
     * @return Text representation of the vector.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < N; i++) {
            if (i > 0) {
                sb.append(' ');
            }
            sb.append(String.format("%.3e", a[i]));
        }
        return sb.append(']').toString();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment