/*
 * Source: GitHub Gist 5b00fbd26d2c110135143f38a1b6424d by @hageldave
 * (last active June 19, 2018 11:53).
 */
package ezLinalg;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import java.util.Scanner;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import hageldave.imagingkit.core.Img;
import hageldave.imagingkit.core.ImgBase;
import hageldave.imagingkit.core.Pixel;
import hageldave.imagingkit.core.PixelBase;
import hageldave.imagingkit.core.PixelConvertingSpliterator.PixelConverter;
import hageldave.imagingkit.core.io.ImageLoader;
import hageldave.imagingkit.core.scientific.ColorImg;
import hageldave.imagingkit.core.util.ImageFrame;
/**
 * Small dense linear-algebra toolkit built around the {@link Mat} interface, plus two
 * fully-connected neural-network demos trained by plain gradient descent.
 * <p>
 * Structural operations (transpose, negation, slicing, stacking, reshaping, meshgrid, ...)
 * return read-only views that delegate to their source matrices instead of copying values.
 * Arithmetic operations ({@link #plus}, {@link #minus}, {@link #mult}, ...) return freshly
 * allocated {@link Matrix} instances.
 */
public class LinAlg {

    /**
     * Read-only 2D matrix of doubles. Implementors only have to provide {@link #rows()},
     * {@link #cols()} and {@link #val(int, int)}; everything else has a default.
     */
    public static interface Mat {

        /** @return number of rows */
        int rows();

        /** @return number of columns */
        int cols();

        /**
         * Returns the element at a linear index, interpreting the matrix in row-major order.
         *
         * @param idx linear index in {@code [0, numValues())}
         * @return value at row {@code idx / cols()}, column {@code idx % cols()}
         */
        default double val(int idx) {
            return val(idx / cols(), idx % cols());
        }

        /**
         * @param r row index
         * @param c column index
         * @return value at (r, c)
         */
        double val(int r, int c);

        /** @return total number of elements, {@code rows() * cols()} */
        default int numValues() {
            return rows() * cols();
        }

        /**
         * Number of view layers between this matrix and concrete data. Defaults to 0; every
         * view created by this file reports its source's count plus one.
         */
        default int wrapCount() {
            return 0;
        }

        /** @return sum of the squared elements (squared Frobenius norm) */
        default double l2norm_squared() {
            double sum = 0;
            for (int i = 0; i < numValues(); i++) {
                double v = val(i);
                sum += v * v;
            }
            return sum;
        }

        /** @return shape description such as {@code [2 rows 3 cols]} */
        default String shapeString() {
            return String.format("[%d rows %d cols]", rows(), cols());
        }

        /** @return a view of the transpose of this matrix (no copy) */
        default Mat transpose() {
            Mat self = this;
            return new Mat() {
                @Override
                public int rows() {return self.cols();}
                @Override
                public int cols() {return self.rows();}
                @Override
                public double val(int r, int c) {return self.val(c, r);}
                @Override
                public String toString() {return asString();}
                @Override
                public int wrapCount() {return self.wrapCount() + 1;}
            };
        }

        /** @return a view of this matrix with every element negated (no copy) */
        default Mat negative() {
            Mat self = this;
            return new Mat() {
                @Override
                public int rows() {return self.rows();}
                @Override
                public int cols() {return self.cols();}
                @Override
                public double val(int idx) {return -self.val(idx);}
                @Override
                public double val(int r, int c) {return -self.val(r, c);}
                @Override
                public int wrapCount() {return self.wrapCount() + 1;}
                @Override
                public String toString() {return asString();}
            };
        }

        /**
         * @return a {@code numValues() x 1} column-vector view of this matrix's elements in
         *         row-major order (no copy)
         */
        default Mat asVector() {
            Mat self = this;
            return new Mat() {
                @Override
                public int rows() {return self.numValues();}
                @Override
                public int cols() {return 1;}
                @Override
                public double val(int r, int c) {return self.val(r);}
                @Override
                public int wrapCount() {return self.wrapCount() + 1;}
                @Override
                public String toString() {return asString();}
            };
        }

        /**
         * Human-readable rendering: one line per row, tab-separated values with 4 decimals.
         * Uses the default locale, so the decimal separator may vary between systems.
         */
        default String asString() {
            StringBuilder sb = new StringBuilder();
            for (int r = 0; r < rows(); r++) {
                for (int c = 0; c < cols(); c++) {
                    sb.append(String.format("% .4f", val(r, c)));
                    if (c < cols() - 1)
                        sb.append('\t');
                }
                sb.append('\n');
            }
            return sb.toString();
        }

        /**
         * Renders the matrix as a MATLAB literal, e.g. {@code [1.0000,2.0000;3.0000,4.0000]}.
         * Numbers are always formatted with {@link Locale#ROOT} so the decimal separator is a
         * dot; a locale-specific comma would collide with the column separator.
         */
        default String matlabString() {
            StringBuilder sb = new StringBuilder();
            sb.append('[');
            for (int r = 0; r < rows(); r++) {
                for (int c = 0; c < cols(); c++) {
                    sb.append(String.format(Locale.ROOT, "%.4f", val(r, c)));
                    if (c < cols() - 1)
                        sb.append(',');
                }
                if (r < rows() - 1)
                    sb.append(';');
            }
            sb.append(']');
            return sb.toString();
        }

        /** @return a dense, mutable copy of this matrix */
        default Matrix copy() {
            Matrix m = new Matrix(rows(), cols());
            for (int r = 0; r < rows(); r++) {
                for (int c = 0; c < cols(); c++) {
                    m.set(r, c, val(r, c));
                }
            }
            return m;
        }

        /**
         * Returns a view of rows {@code from..to} (both inclusive).
         *
         * @throws IllegalArgumentException if {@code from > to}, {@code from < 0} or
         *         {@code to >= rows()}
         */
        default Mat sliceRows(final int from, final int to) {
            if (from > to) {
                throw new IllegalArgumentException("from greater than to, " + from + " " + to);
            }
            if (from < 0) {
                throw new IllegalArgumentException("from out of bounds, " + from + " < 0");
            }
            if (to >= rows()) {
                throw new IllegalArgumentException("to out of bounds, " + to + " >= " + rows() + " (rows)");
            }
            Mat self = this;
            final int rows = to + 1 - from;
            return new Mat() {
                @Override
                public int rows() {return rows;}
                @Override
                public int cols() {return self.cols();}
                @Override
                public double val(int r, int c) {return self.val(r + from, c);}
                @Override
                public int wrapCount() {return self.wrapCount() + 1;}
                @Override
                public String toString() {return asString();}
            };
        }

        /**
         * Returns a view of columns {@code from..to} (both inclusive).
         *
         * @throws IllegalArgumentException if {@code from > to}, {@code from < 0} or
         *         {@code to >= cols()}
         */
        default Mat sliceCols(final int from, final int to) {
            if (from > to) {
                throw new IllegalArgumentException("from greater than to, " + from + " " + to);
            }
            if (from < 0) {
                throw new IllegalArgumentException("from out of bounds, " + from + " < 0");
            }
            if (to >= cols()) {
                throw new IllegalArgumentException("to out of bounds, " + to + " >= " + cols() + " (cols)");
            }
            Mat self = this;
            final int cols = to + 1 - from;
            return new Mat() {
                @Override
                public int rows() {return self.rows();}
                @Override
                public int cols() {return cols;}
                @Override
                public double val(int r, int c) {return self.val(r, c + from);}
                @Override
                public int wrapCount() {return self.wrapCount() + 1;}
                @Override
                public String toString() {return asString();}
            };
        }
    }

    /** Dense, mutable matrix backed by a row-major {@code double[]}. */
    public static class Matrix implements Mat {
        /** Row-major element storage: element (r, c) lives at {@code r * cols + c}. */
        public final double[] values;
        public final int rows;
        public final int cols;

        /** Creates a zero-filled {@code rows x cols} matrix. */
        public Matrix(int rows, int cols) {
            this(rows, cols, new double[rows * cols]);
        }

        /**
         * Wraps (does not copy) the given row-major array.
         *
         * @throws IllegalArgumentException if {@code values.length != rows * cols}
         */
        public Matrix(int rows, int cols, double[] values) {
            if (rows * cols != values.length) {
                throw new IllegalArgumentException(
                        String.format("num values does not match dimensions, n=%d, rows=%d, cols=%d", values.length, rows, cols));
            }
            this.rows = rows;
            this.cols = cols;
            this.values = values;
        }

        @Override
        public int rows() {
            return rows;
        }

        @Override
        public int cols() {
            return cols;
        }

        @Override
        public double val(int r, int c) {
            return val(r * cols + c);
        }

        @Override
        public double val(int idx) {
            return values[idx];
        }

        /** Sets element (r, c) and returns its previous value. */
        public double set(int r, int c, double v) {
            return set(r * cols + c, v);
        }

        /** Sets the element at row-major index {@code idx} and returns its previous value. */
        public double set(int idx, double v) {
            double prev = values[idx];
            values[idx] = v;
            return prev;
        }

        /** Copies all of {@code m} into this matrix with its top-left corner at (r, c). */
        public void set(int r, int c, Mat m) {
            for (int row = 0; row < m.rows(); row++) {
                for (int col = 0; col < m.cols(); col++) {
                    set(row + r, col + c, m.val(row, col));
                }
            }
        }

        @Override
        public String toString() {
            return asString();
        }
    }

    /**
     * Throws an {@link IllegalArgumentException} with the supplied message unless the two
     * objects are equal (null-safe, via {@link Objects#equals}).
     */
    static void requireEqual(Object o1, Object o2, Supplier<String> errmsg) {
        if (!Objects.equals(o1, o2)) {
            throw new IllegalArgumentException(errmsg.get());
        }
    }

    /** @return true if m is 1x1 */
    static boolean isScalar(Mat m) {
        return m.rows() == 1 && m.cols() == 1;
    }

    /** @return true if both matrices have the same number of rows and columns */
    static boolean sameSize(Mat m1, Mat m2) {
        return m1.rows() == m2.rows() && m1.cols() == m2.cols();
    }

    /** @return true if m has as many rows as columns */
    static boolean isSquare(Mat m) {
        return m.rows() == m.cols();
    }

    /** @return true if m is a single row or a single column */
    static boolean isVector(Mat m) {
        return m.cols() == 1 || m.rows() == 1;
    }

    /** @return true if {@code m1 * m2} is defined (m1's column count equals m2's row count) */
    static boolean canMultiply(Mat m1, Mat m2) {
        return m1.cols() == m2.rows();
    }

    /**
     * Element-wise sum. A 1x1 operand is broadcast as a scalar to the other operand.
     *
     * @throws IllegalArgumentException if neither operand is 1x1 and the shapes differ
     */
    static Matrix plus(Mat m1, Mat m2) {
        if (isScalar(m1)) {
            return plus(m1.val(0), m2);
        }
        if (isScalar(m2)) {
            return plus(m1, m2.val(0));
        }
        if (!sameSize(m1, m2)) {
            throw new IllegalArgumentException(
                    String.format("cannot add matrices of unequal dimensions, m1%s m2%s", m1.shapeString(), m2.shapeString()));
        }
        Matrix result = new Matrix(m1.rows(), m1.cols());
        for (int r = 0; r < result.rows(); r++) {
            for (int c = 0; c < result.cols(); c++) {
                result.set(r, c, m1.val(r, c) + m2.val(r, c));
            }
        }
        return result;
    }

    /** @return a new matrix with {@code scalar + m(r,c)} at every position */
    static Matrix plus(double scalar, Mat m) {
        Matrix result = new Matrix(m.rows(), m.cols());
        for (int r = 0; r < result.rows(); r++) {
            for (int c = 0; c < result.cols(); c++) {
                result.set(r, c, scalar + m.val(r, c));
            }
        }
        return result;
    }

    /** @return a new matrix with {@code m(r,c) + scalar} at every position */
    static Matrix plus(Mat m, double scalar) {
        Matrix result = new Matrix(m.rows(), m.cols());
        for (int r = 0; r < result.rows(); r++) {
            for (int c = 0; c < result.cols(); c++) {
                result.set(r, c, m.val(r, c) + scalar);
            }
        }
        return result;
    }

    /** Element-wise difference {@code m1 - m2}, with the same scalar broadcasting as {@link #plus(Mat, Mat)}. */
    static Matrix minus(Mat m1, Mat m2) {
        return plus(m1, m2.negative());
    }

    /** @return a new matrix with {@code scalar - m(r,c)} at every position */
    static Matrix minus(double scalar, Mat m) {
        return plus(scalar, m.negative());
    }

    /** @return a new matrix with {@code m(r,c) - scalar} at every position */
    static Matrix minus(Mat m, double scalar) {
        return plus(m, -scalar);
    }

    /**
     * Matrix product {@code m1 * m2}. A 1x1 operand is treated as a scalar factor instead.
     *
     * @throws IllegalArgumentException if neither operand is 1x1 and {@code m1.cols() != m2.rows()}
     */
    static Matrix mult(Mat m1, Mat m2) {
        if (isScalar(m1)) {
            return mult(m1.val(0), m2);
        }
        if (isScalar(m2)) {
            return mult(m2.val(0), m1);
        }
        if (!canMultiply(m1, m2)) {
            throw new IllegalArgumentException(
                    String.format("Cannot multiply matrices, dimensions dont match. m1%s m2%s",
                            m1.shapeString(), m2.shapeString()));
        }
        Matrix result = new Matrix(m1.rows(), m2.cols());
        for (int r = 0; r < result.rows(); r++) {
            for (int c = 0; c < result.cols(); c++) {
                double v = 0;
                for (int k = 0; k < m1.cols(); k++) {
                    v += m1.val(r, k) * m2.val(k, c);
                }
                result.set(r, c, v);
            }
        }
        return result;
    }

    /** @return a new matrix with every element of m multiplied by scalar */
    static Matrix mult(double scalar, Mat m) {
        Matrix result = new Matrix(m.rows(), m.cols());
        for (int i = 0; i < result.numValues(); i++) {
            result.set(i, scalar * m.val(i));
        }
        return result;
    }

    /** @return a new matrix with every element of m multiplied by scalar */
    static Matrix mult(Mat m, double scalar) {
        return mult(scalar, m);
    }

    /**
     * Element-wise (Hadamard) product.
     *
     * @throws IllegalArgumentException if the shapes differ (no scalar broadcasting here)
     */
    static Matrix multElementWise(Mat m1, Mat m2) {
        if (!sameSize(m1, m2)) {
            throw new IllegalArgumentException(
                    String.format("cannot elem.wise mult matrices of unequal dimensions, m1%s m2%s", m1.shapeString(), m2.shapeString()));
        }
        Matrix result = new Matrix(m1.rows(), m1.cols());
        for (int r = 0; r < result.rows(); r++) {
            for (int c = 0; c < result.cols(); c++) {
                result.set(r, c, m1.val(r, c) * m2.val(r, c));
            }
        }
        return result;
    }

    /** @return a read-only 1x1 matrix holding scalar */
    static Mat scalar(double scalar) {
        return new Mat() {
            @Override
            public int rows() {return 1;}
            @Override
            public int cols() {return 1;}
            @Override
            public double val(int r, int c) {return scalar;}
            @Override
            public String toString() {return asString();}
        };
    }

    /** @return a column vector wrapping (not copying) the given values */
    static Matrix vector(double... values) {
        return new Matrix(values.length, 1, values);
    }

    /** @return a read-only {@code rows x cols} matrix whose every element is value */
    static Mat allSame(int rows, int cols, double value) {
        return new Mat() {
            @Override
            public int rows() {return rows;}
            @Override
            public int cols() {return cols;}
            @Override
            public double val(int r, int c) {return value;}
            @Override
            public String toString() {return asString();}
        };
    }

    /** @return a read-only all-zero matrix */
    static Mat zeros(int rows, int cols) {
        return allSame(rows, cols, 0);
    }

    /** @return a read-only all-one matrix */
    static Mat ones(int rows, int cols) {
        return allSame(rows, cols, 1);
    }

    /** @return a read-only matrix with ones on the main diagonal and zeros elsewhere */
    static Mat eye(int rows, int cols) {
        return new Mat() {
            @Override
            public int rows() {return rows;}
            @Override
            public int cols() {return cols;}
            @Override
            public double val(int r, int c) {return r == c ? 1 : 0;}
            @Override
            public String toString() {return asString();}
        };
    }

    /**
     * Returns a read-only diagonal matrix. The first {@code diagvalues.length} diagonal entries
     * take the given values; any remaining diagonal entries, and all off-diagonal entries, are 0.
     *
     * @throws IllegalArgumentException if more values are given than fit on the diagonal
     */
    static Mat diag(int rows, int cols, double... diagvalues) {
        if (rows < diagvalues.length || cols < diagvalues.length) {
            throw new IllegalArgumentException(
                    String.format("provided more values (%d) than fit on diagonal of %dx%d", diagvalues.length, rows, cols));
        }
        // copy so later changes to the caller's array do not alter this matrix
        final double[] diagonal = diagvalues.clone();
        return new Mat() {
            @Override
            public int rows() {return rows;}
            @Override
            public int cols() {return cols;}
            @Override
            public double val(int r, int c) {
                // diagonal positions past the supplied values are zero, not an index error
                return (r == c && r < diagonal.length) ? diagonal[r] : 0;
            }
            @Override
            public String toString() {return asString();}
        };
    }

    /** @return a new matrix with fn applied to every element of m */
    static Mat elementWise(Mat m, DoubleUnaryOperator fn) {
        Matrix result = new Matrix(m.rows(), m.cols());
        for (int r = 0; r < result.rows(); r++) {
            for (int c = 0; c < result.cols(); c++) {
                result.set(r, c, fn.applyAsDouble(m.val(r, c)));
            }
        }
        return result;
    }

    /** @return sum of all elements of m (0 for an empty matrix) */
    static double sum(Mat m) {
        double sum = 0;
        for (int i = 0; i < m.numValues(); i++) {
            sum += m.val(i);
        }
        return sum;
    }

    /** @return smallest element of m, or NaN if m has no elements */
    static double min(Mat m) {
        if (m.numValues() < 1) {
            return Double.NaN;
        }
        double min = m.val(0);
        for (int i = 1; i < m.numValues(); i++) {
            min = Math.min(min, m.val(i));
        }
        return min;
    }

    /** @return largest element of m, or NaN if m has no elements */
    static double max(Mat m) {
        if (m.numValues() < 1) {
            return Double.NaN;
        }
        double max = m.val(0);
        for (int i = 1; i < m.numValues(); i++) {
            max = Math.max(max, m.val(i));
        }
        return max;
    }

    /**
     * Returns an n x 1 column vector of evenly spaced values from left to right (both inclusive).
     *
     * @throws IllegalArgumentException if {@code n < 2}
     */
    static Mat linspace(double left, double right, int n) {
        if (n < 2) {
            throw new IllegalArgumentException("n has to be greater than 1, n=" + n);
        }
        double step = (right - left) / (n - 1);
        return new Mat() {
            @Override
            public int rows() {return n;}
            @Override
            public int cols() {return 1;}
            @Override
            public double val(int idx) {return left + idx * step;}
            @Override
            public double val(int r, int c) {return val(r);}
            @Override
            public String toString() {return asString();}
        };
    }

    /**
     * Returns a {@code rows x cols} view over m's elements taken in row-major order.
     *
     * @throws IllegalArgumentException if the element counts differ
     */
    static Mat reshape(Mat m, int rows, int cols) {
        if (m.numValues() != rows * cols) {
            throw new IllegalArgumentException("cannot reshape matrix " + m.shapeString() + " to (" + rows + " x " + cols + "), unequal number of elements");
        }
        return new Mat() {
            @Override
            public int rows() {return rows;}
            @Override
            public int cols() {return cols;}
            @Override
            public double val(int r, int c) {return m.val(r * cols() + c);}
            @Override
            public String toString() {return asString();}
            @Override
            public int wrapCount() {return m.wrapCount() + 1;}
        };
    }

    /**
     * MATLAB-style meshgrid. Both results are {@code y.numValues() x x.numValues()} views:
     * the first repeats x along every row, the second repeats y along every column.
     *
     * @throws IllegalArgumentException if x or y is not a vector
     */
    static Mat[] meshgrid(Mat x, Mat y) {
        if (!isVector(x)) {
            throw new IllegalArgumentException("x is not a vector, " + x.shapeString());
        }
        if (!isVector(y)) {
            throw new IllegalArgumentException("y is not a vector, " + y.shapeString());
        }
        int rows = y.numValues();
        int cols = x.numValues();
        return new Mat[]{
            new Mat() {
                @Override
                public int rows() {return rows;}
                @Override
                public int cols() {return cols;}
                @Override
                public double val(int r, int c) {return x.val(c);}
                @Override
                public String toString() {return asString();}
                @Override
                public int wrapCount() {return x.wrapCount() + 1;}
            },
            new Mat() {
                @Override
                public int rows() {return rows;}
                @Override
                public int cols() {return cols;}
                @Override
                public double val(int r, int c) {return y.val(r);}
                @Override
                public String toString() {return asString();}
                @Override
                public int wrapCount() {return y.wrapCount() + 1;}
            },
        };
    }

    /**
     * Horizontally concatenates the given matrices into one view (no element copy).
     *
     * @throws IllegalArgumentException if no matrices are given or their row counts differ
     */
    static Mat stackCols(Mat... mats) {
        if (mats.length == 0) {
            throw new IllegalArgumentException("Cannot stack zero matrices");
        }
        // copy so later changes to the caller's varargs array do not alter the view
        final Mat[] parts = mats.clone();
        int rows = parts[0].rows();
        int cols = 0;
        int maxWrap = 0;
        for (Mat part : parts) {
            if (rows != part.rows()) {
                throw new IllegalArgumentException("Cannot stack columns with different number of rows");
            }
            cols += part.cols();
            maxWrap = Math.max(maxWrap, part.wrapCount());
        }
        // per output column: which part it comes from and that part's first output column
        final int[] colLUT = new int[cols];
        final int[] colOFF = new int[cols];
        int currentC = 0;
        for (int i = 0; i < parts.length; i++) {
            for (int c = 0; c < parts[i].cols(); c++) {
                colLUT[currentC + c] = i;
                colOFF[currentC + c] = currentC;
            }
            currentC += parts[i].cols();
        }
        final int numrows = rows;
        final int numcols = cols;
        final int maxWrapCount = maxWrap;
        return new Mat() {
            @Override
            public int rows() {return numrows;}
            @Override
            public int cols() {return numcols;}
            @Override
            public double val(int r, int c) {
                return parts[colLUT[c]].val(r, c - colOFF[c]);
            }
            @Override
            public String toString() {return asString();}
            @Override
            public int wrapCount() {return maxWrapCount + 1;}
        };
    }

    /**
     * Vertically concatenates the given matrices into one view (no element copy).
     *
     * @throws IllegalArgumentException if no matrices are given or their column counts differ
     */
    static Mat stackRows(Mat... mats) {
        if (mats.length == 0) {
            throw new IllegalArgumentException("Cannot stack zero matrices");
        }
        // copy so later changes to the caller's varargs array do not alter the view
        final Mat[] parts = mats.clone();
        int cols = parts[0].cols();
        int rows = 0;
        int maxWrap = 0;
        for (Mat part : parts) {
            if (cols != part.cols()) {
                throw new IllegalArgumentException("Cannot stack rows with different number of columns");
            }
            rows += part.rows();
            maxWrap = Math.max(maxWrap, part.wrapCount());
        }
        // per output row: which part it comes from and that part's first output row
        final int[] rowLUT = new int[rows];
        final int[] rowOFF = new int[rows];
        int currentR = 0;
        for (int i = 0; i < parts.length; i++) {
            for (int r = 0; r < parts[i].rows(); r++) {
                rowLUT[currentR + r] = i;
                rowOFF[currentR + r] = currentR;
            }
            currentR += parts[i].rows();
        }
        final int numrows = rows;
        final int numcols = cols;
        final int maxWrapCount = maxWrap;
        return new Mat() {
            @Override
            public int rows() {return numrows;}
            @Override
            public int cols() {return numcols;}
            @Override
            public double val(int r, int c) {
                return parts[rowLUT[r]].val(r - rowOFF[r], c);
            }
            @Override
            public String toString() {return asString();}
            @Override
            public int wrapCount() {return maxWrapCount + 1;}
        };
    }

    /**
     * Reads a whitespace-separated numeric text file: one matrix row per non-blank line.
     * Leading/trailing whitespace and runs of spaces or tabs between values are tolerated.
     *
     * @param f file to read
     * @return a dense matrix with one row per non-blank line (0x0 if the file has none)
     * @throws UncheckedIOException if the file cannot be opened
     * @throws NumberFormatException if a token is not a number
     * @throws IllegalArgumentException if lines contain differing numbers of values
     */
    static Mat readFile(File f) {
        try (Scanner sc = new Scanner(f)) {
            ArrayList<double[]> rows = new ArrayList<>();
            int ncols = -1;
            int lineNumber = 0;
            while (sc.hasNextLine()) {
                String line = sc.nextLine().trim();
                lineNumber++;
                if (line.isEmpty())
                    continue;
                double[] row = Arrays.stream(line.split("\\s+")).mapToDouble(Double::parseDouble).toArray();
                if (ncols < 0) {
                    ncols = row.length;
                } else if (row.length != ncols) {
                    throw new IllegalArgumentException(String.format(
                            "line %d of %s has %d values, expected %d", lineNumber, f, row.length, ncols));
                }
                rows.add(row);
            }
            int width = Math.max(ncols, 0);
            Matrix result = new Matrix(rows.size(), width);
            for (int r = 0; r < rows.size(); r++) {
                System.arraycopy(rows.get(r), 0, result.values, r * width, width);
            }
            return result;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Demo: loads a before/after color-correction image from a URL, splits it into a left
     * half (input) and right half (target), both without the bottom 30 pixel rows, learns a
     * per-pixel color mapping with {@link #doNNImg}, and displays the input, target and the
     * input re-colored by the learned mapping.
     * <p>
     * The binary-classification demo {@link #doNN()} can be run instead.
     */
    public static void main(String[] args) {
        Img loaded = ImageLoader.loadImgFromURL("http://fixthephoto.com/UserFiles/color-correction-before-after.jpg");
        Img img1 = loaded.copyArea(0, 0, loaded.getWidth() / 2, loaded.getHeight() - 30, null, 0, 0);
        Img img2 = loaded.copyArea(loaded.getWidth() / 2, 0, loaded.getWidth() / 2, loaded.getHeight() - 30, null, 0, 0);
        UnaryOperator<Mat> transform = doNNImg(img1, img2);
        ImageFrame.display(img1).setTitle("orig");
        ImageFrame.display(img2).setTitle("ytrue");
        // converts each pixel to a 3x1 (r, g, b) vector and back
        PixelConverter<Pixel, Matrix> conv = new PixelConverter<Pixel, Matrix>() {
            @Override
            public Matrix allocateElement() {
                return new Matrix(3, 1);
            }
            @Override
            public void convertPixelToElement(Pixel px, Matrix element) {
                element.set(0, px.r_asDouble());
                element.set(1, px.g_asDouble());
                element.set(2, px.b_asDouble());
            }
            @Override
            public void convertElementToPixel(Matrix element, Pixel px) {
                px.setRGB_fromDouble(element.val(0), element.val(1), element.val(2));
            }
        };
        Img img3 = img1.copy();
        // prepend the constant 1 input used during training, then overwrite the pixel vector with the prediction
        img3.forEach(conv, true, (vec) -> {Mat r = transform.apply(stackRows(scalar(1), vec)); vec.set(0, 0, r);});
        ImageFrame.display(img3).setTitle("reconst");
    }

    /**
     * Trains a fully-connected network (inputs-100-1, sigmoid hidden activation, linear output)
     * on {@code data2Class_adjusted.txt} with the hinge loss {@code max(0, 1 - pred*y)} and
     * full-batch gradient descent (step 0.05), then prints the learned layers as MATLAB literals.
     * Columns 0..2 of the file are used as features and column 3 as label (presumably
     * +1/-1, given the hinge loss — verify against the data file).
     * <p>
     * Loops until the mean loss is at most 0.01; there is no iteration cap.
     */
    static void doNN() {
        Mat data = readFile(new File("data2Class_adjusted.txt"));
        Mat X = data.sliceCols(0, 2);
        Mat Y = data.sliceCols(3, 3);
        int n = X.rows();
        int m = X.cols();
        int[] l = new int[]{m, 100, 1};
        int numlayers = l.length - 1;
        Mat[] layers = new Mat[numlayers];
        for (int i = 0; i < numlayers; i++) {
            // weights drawn uniformly from [-1, 1)
            layers[i] = elementWise(zeros(l[i + 1], l[i]), (v) -> Math.random() * 2 - 1);
        }
        Mat[] gradients = new Mat[numlayers];
        DoubleUnaryOperator sigmoid = (x) -> 1.0 / (1.0 + Math.exp(-x));
        DoubleBinaryOperator lossfn = (pred, ytrue) -> Math.max(0, 1.0 - pred * ytrue);
        Mat alpha = scalar(0.05);
        double totalLoss = 10;
        int iter = 0;
        while (totalLoss > 0.01) {
            for (int i = 0; i < numlayers; i++) {
                gradients[i] = zeros(l[i + 1], l[i]);
            }
            Matrix losses = new Matrix(n, 1);
            for (int i = 0; i < n; i++) {
                Mat[] xes = new Mat[numlayers];
                Mat x = X.sliceRows(i, i).transpose();
                Mat f = fwd(x, layers, xes, sigmoid);
                double loss = lossfn.applyAsDouble(f.val(0), Y.val(i));
                losses.set(i, loss);
                // (sub)gradient of the hinge loss w.r.t. the prediction
                Mat delta = scalar(-Y.val(i) * (1 - f.val(0) * Y.val(i) > 0 ? 1 : 0));
                gradients = bwd(delta, layers, xes, gradients);
            }
            for (int i = 0; i < numlayers; i++) {
                layers[i] = minus(layers[i], mult(alpha, gradients[i]));
            }
            totalLoss = sum(losses) / n;
            if ((iter++) % 100 == 0)
                System.out.println(iter + " " + totalLoss);
        }
        for (int i = 0; i < numlayers; i++) {
            System.out.println("layers{" + (i + 1) + "} = " + layers[i].matlabString() + ";");
        }
    }

    /**
     * Learns a per-pixel color mapping from {@code initial} to {@code target}.
     * <p>
     * Each pixel becomes one sample: input {@code [1, r, g, b]} from initial, target
     * {@code [r, g, b]} from the pixel with the same index in target. A 4-8-8-3 fully-connected
     * network (sigmoid hidden activations, linear output) is trained with full-batch gradient
     * descent on the squared error, step size {@code 0.05 / n}. The loss per iteration and the
     * final layers are printed. Loops until the mean loss is at most 0.14; there is no
     * iteration cap.
     *
     * @param initial input image
     * @param target  desired output image, same dimensions as initial
     * @return function mapping a 4x1 input vector {@code [1, r, g, b]} to the predicted 3x1 color
     * @throws IllegalArgumentException if the image dimensions differ
     */
    static UnaryOperator<Mat> doNNImg(ImgBase<? extends PixelBase> initial, ImgBase<? extends PixelBase> target) {
        requireEqual(initial.getDimension(), target.getDimension(), () -> "images are not of same dims");
        // one row per pixel: [1, r, g, b] of initial followed by [r, g, b] of target
        Matrix data = new Matrix(initial.numValues(), 7);
        for (PixelBase p : initial) {
            data.set(p.getIndex(), 0, 1);
            data.set(p.getIndex(), 1, p.r_asDouble());
            data.set(p.getIndex(), 2, p.g_asDouble());
            data.set(p.getIndex(), 3, p.b_asDouble());
        }
        for (PixelBase p : target) {
            data.set(p.getIndex(), 4, p.r_asDouble());
            data.set(p.getIndex(), 5, p.g_asDouble());
            data.set(p.getIndex(), 6, p.b_asDouble());
        }
        Mat X = data.sliceCols(0, 3);
        Mat Y = data.sliceCols(4, 6);
        int n = X.rows();
        int m = X.cols();
        int mY = Y.cols();
        int[] l = new int[]{m, 8, 8, mY};
        int numlayers = l.length - 1;
        Mat[] layers = new Mat[numlayers];
        for (int i = 0; i < numlayers; i++) {
            // weights drawn uniformly from [-1, 1)
            layers[i] = elementWise(zeros(l[i + 1], l[i]), (v) -> Math.random() * 2 - 1);
        }
        Mat[] gradients = new Mat[numlayers];
        DoubleUnaryOperator sigmoid = (x) -> 1.0 / (1.0 + Math.exp(-x));
        BiFunction<Mat, Mat, Double> lossfn = (pred, ytrue) -> minus(pred, ytrue).l2norm_squared();
        // gradient of the squared error w.r.t. the prediction, as a row vector for bwd
        BinaryOperator<Mat> deltafn = (pred, ytrue) -> mult(2, minus(pred, ytrue)).transpose();
        Mat alpha = scalar(0.05 * (1.0 / n));
        double totalLoss = 10;
        int iter = 0;
        while (totalLoss > 0.14) {
            for (int i = 0; i < numlayers; i++) {
                gradients[i] = zeros(l[i + 1], l[i]);
            }
            Matrix losses = new Matrix(n, 1);
            for (int i = 0; i < n; i++) {
                Mat[] xes = new Mat[numlayers];
                Mat x = X.sliceRows(i, i).transpose();
                Mat f = fwd(x, layers, xes, sigmoid);
                Mat y = Y.sliceRows(i, i).transpose();
                double loss = lossfn.apply(f, y);
                losses.set(i, loss);
                Mat delta = deltafn.apply(f, y);
                gradients = bwd(delta, layers, xes, gradients);
            }
            totalLoss = sum(losses) / n;
            for (int i = 0; i < numlayers; i++) {
                layers[i] = minus(layers[i], mult(alpha, gradients[i]));
            }
            System.out.println(iter + " " + totalLoss);
            iter++;
        }
        for (int i = 0; i < numlayers; i++) {
            System.out.println("layers{" + (i + 1) + "} = " + layers[i].matlabString() + ";");
        }
        for (int i = 0; i < numlayers; i++) {
            System.out.println(layers[i]);
        }
        return (x) -> fwd(x, layers, null, sigmoid);
    }

    /**
     * Forward pass: {@code z = layers[i] * x} for each layer, with activationFn applied
     * element-wise to every layer's output except the last.
     *
     * @param x            input column vector
     * @param layers       weight matrices, applied in order
     * @param xes          if non-null, receives the input fed to each layer (needed by {@link #bwd})
     * @param activationFn element-wise activation for hidden layers
     * @return the last layer's output without activation (null if layers is empty)
     */
    static Mat fwd(Mat x, Mat[] layers, Mat[] xes, DoubleUnaryOperator activationFn) {
        int numlayers = layers.length;
        Mat z = null;
        for (int i = 0; i < numlayers; i++) {
            if (xes != null)
                xes[i] = x;
            z = mult(layers[i], x);
            if (i < numlayers - 1)
                x = elementWise(z, activationFn);
        }
        return z;
    }

    /**
     * Backward pass: adds this sample's weight gradients to {@code gradients} in place.
     * The hidden-layer derivative is computed as {@code x * (1 - x)} from the recorded layer
     * inputs, which matches the logistic sigmoid used by the demos; other activations would
     * need a different derivative.
     *
     * @param delta     gradient of the loss w.r.t. the network output, as a row vector
     * @param layers    weight matrices as used in {@link #fwd}
     * @param xes       per-layer inputs recorded by {@link #fwd}
     * @param gradients per-layer gradient accumulators; entries are replaced by the new sums
     * @return the same gradients array
     */
    static Mat[] bwd(Mat delta, Mat[] layers, Mat[] xes, Mat[] gradients) {
        int numlayers = layers.length;
        for (int i = numlayers - 1; i >= 0; i--) {
            Mat x = xes[i];
            Mat gradient = mult(delta.transpose(), x.transpose());
            gradients[i] = plus(gradients[i], gradient);
            if (i > 0) {
                // propagate delta through layer i and the activation of layer i-1;
                // not needed for the first layer, whose input is raw data
                Mat xmx = multElementWise(x, minus(1, x)).transpose();
                delta = multElementWise(mult(delta, layers[i]), xmx);
            }
        }
        return gradients;
    }
}
// End of gist.