Last active
June 19, 2018 11:53
-
-
Save hageldave/5b00fbd26d2c110135143f38a1b6424d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package ezLinalg; | |
import java.io.File; | |
import java.io.IOException; | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.Scanner; | |
import java.util.function.BiFunction; | |
import java.util.function.BinaryOperator; | |
import java.util.function.DoubleBinaryOperator; | |
import java.util.function.DoubleUnaryOperator; | |
import java.util.function.Function; | |
import java.util.function.Supplier; | |
import java.util.function.UnaryOperator; | |
import hageldave.imagingkit.core.Img; | |
import hageldave.imagingkit.core.ImgBase; | |
import hageldave.imagingkit.core.Pixel; | |
import hageldave.imagingkit.core.PixelBase; | |
import hageldave.imagingkit.core.PixelConvertingSpliterator.PixelConverter; | |
import hageldave.imagingkit.core.io.ImageLoader; | |
import hageldave.imagingkit.core.scientific.ColorImg; | |
import hageldave.imagingkit.core.util.ImageFrame; | |
public class LinAlg { | |
public static interface Mat { | |
int rows(); | |
int cols(); | |
default double val(int idx){ | |
return val(idx/cols(),idx%cols()); | |
} | |
double val(int r,int c); | |
default int numValues() { | |
return rows()*cols(); | |
} | |
default int wrapCount(){ | |
return 0; | |
} | |
default double l2norm_squared() { | |
double sum = 0; | |
for(int i = 0; i < numValues(); i++) | |
sum += val(i)*val(i); | |
return sum; | |
} | |
default String shapeString(){ | |
return String.format("[%d rows %d cols]", rows(), cols()); | |
} | |
default Mat transpose(){ | |
Mat self = this; | |
return new Mat() { | |
@Override | |
public int rows() {return self.cols();} | |
@Override | |
public int cols() {return self.rows();} | |
@Override | |
public double val(int r, int c) {return self.val(c, r);} | |
@Override | |
public String toString() {return asString();} | |
@Override | |
public int wrapCount() {return self.wrapCount()+1;} | |
}; | |
} | |
default Mat negative(){ | |
Mat self = this; | |
return new Mat(){ | |
@Override | |
public int rows() {return self.rows();} | |
@Override | |
public int cols() {return self.cols();} | |
@Override | |
public double val(int idx) {return -self.val(idx);} | |
@Override | |
public double val(int r, int c) {return -self.val(r,c);} | |
@Override | |
public int wrapCount() {return self.wrapCount()+1;} | |
@Override | |
public String toString() {return asString();} | |
}; | |
} | |
default Mat asVector(){ | |
Mat self = this; | |
return new Mat(){ | |
@Override | |
public int rows() {return self.numValues();} | |
@Override | |
public int cols() {return 1;} | |
@Override | |
public double val(int r, int c) {return self.val(r);} | |
@Override | |
public int wrapCount() {return self.wrapCount()+1;} | |
@Override | |
public String toString() {return asString();} | |
}; | |
} | |
default String asString() { | |
StringBuilder sb = new StringBuilder(); | |
for(int r=0; r<rows();r++){ | |
for(int c=0; c<cols();c++){ | |
sb.append(String.format("% .4f", val(r,c))); | |
if(c<cols()-1) | |
sb.append('\t'); | |
} | |
sb.append('\n'); | |
} | |
return sb.toString(); | |
} | |
default String matlabString() { | |
StringBuilder sb = new StringBuilder(); | |
sb.append('['); | |
for(int r=0; r<rows();r++){ | |
for(int c=0; c<cols();c++){ | |
sb.append(String.format("%.4f", val(r,c))); | |
if(c<cols()-1) | |
sb.append(','); | |
} | |
if(r < rows()-1) | |
sb.append(';'); | |
} | |
sb.append(']'); | |
return sb.toString(); | |
} | |
default Matrix copy() { | |
Matrix m = new Matrix(rows(), cols()); | |
for(int r=0; r<rows();r++){ | |
for(int c=0; c<cols();c++){ | |
m.set(r, c, val(r, c)); | |
} | |
} | |
return m; | |
} | |
default Mat sliceRows(final int from, final int to){ | |
if(from > to){ | |
throw new IllegalArgumentException("from greater than to, " + from + " " + to); | |
} | |
if(from < 0){ | |
throw new IllegalArgumentException("from out of bounds, " + from + " < 0"); | |
} | |
if(to >= rows()){ | |
throw new IllegalArgumentException("to out of bounds, " + to + " > " + rows() + " (rows)"); | |
} | |
Mat self = this; | |
final int rows = to+1-from; | |
return new Mat(){ | |
@Override | |
public int rows() { | |
return rows; | |
} | |
@Override | |
public int cols() { | |
return self.cols(); | |
} | |
@Override | |
public double val(int r, int c) { | |
return self.val(r+from, c); | |
} | |
@Override | |
public int wrapCount() { | |
return self.wrapCount()+1; | |
} | |
@Override | |
public String toString() { | |
return asString(); | |
} | |
}; | |
} | |
default Mat sliceCols(final int from, final int to){ | |
if(from > to){ | |
throw new IllegalArgumentException("from greater than to, " + from + " " + to); | |
} | |
if(from < 0){ | |
throw new IllegalArgumentException("from out of bounds, " + from + " < 0"); | |
} | |
if(to >= cols()){ | |
throw new IllegalArgumentException("to out of bounds, " + to + " > " + cols() + " (cols)"); | |
} | |
Mat self = this; | |
final int cols = to+1-from; | |
return new Mat(){ | |
@Override | |
public int rows() { | |
return self.rows(); | |
} | |
@Override | |
public int cols() { | |
return cols; | |
} | |
@Override | |
public double val(int r, int c) { | |
return self.val(r, c+from); | |
} | |
@Override | |
public int wrapCount() { | |
return self.wrapCount()+1; | |
} | |
@Override | |
public String toString() { | |
return asString(); | |
} | |
}; | |
} | |
} | |
public static class Matrix implements Mat { | |
public final double[] values; | |
public final int rows; | |
public final int cols; | |
public Matrix(int rows, int cols) { | |
this(rows,cols,new double[rows*cols]); | |
} | |
public Matrix(int rows, int cols, double[] values){ | |
if(rows*cols != values.length){ | |
throw new IllegalArgumentException( | |
String.format("num values does not match dimensions, n=%d, rows=%d, cols=%d", values.length,rows,cols)); | |
} | |
this.rows=rows; | |
this.cols=cols; | |
this.values=values; | |
} | |
@Override | |
public int rows() { | |
return rows; | |
} | |
@Override | |
public int cols() { | |
return cols; | |
} | |
@Override | |
public double val(int r, int c) { | |
return val(r*cols+c); | |
} | |
@Override | |
public double val(int idx) { | |
return values[idx]; | |
} | |
public double set(int r, int c, double v){ | |
return set(r*cols+c, v); | |
} | |
public double set(int idx, double v){ | |
double prev = values[idx]; | |
values[idx] = v; | |
return prev; | |
} | |
public void set(int r, int c, Mat m){ | |
for(int row=0; row<m.rows(); row++){ | |
for(int col=0; col<m.cols(); col++){ | |
set(row+r, col+c, m.val(row, col)); | |
} | |
} | |
} | |
@Override | |
public String toString() { | |
return asString(); | |
} | |
} | |
static void requireEqual(Object o1, Object o2, Supplier<String> errmsg){ | |
if(!o1.equals(o2)){ | |
throw new IllegalArgumentException(errmsg.get()); | |
} | |
} | |
static boolean isScalar(Mat m){ | |
return m.rows() == 1 && m.cols() == 1; | |
} | |
static boolean sameSize(Mat m1, Mat m2){ | |
return m1.rows()==m2.rows() && m1.cols()==m2.cols(); | |
} | |
static boolean isSquare(Mat m){ | |
return m.rows()==m.cols(); | |
} | |
static boolean isVector(Mat m){ | |
return m.cols()==1 || m.rows()==1; | |
} | |
static boolean canMultiply(Mat m1, Mat m2){ | |
return m1.cols()==m2.rows(); | |
} | |
static Matrix plus(Mat m1, Mat m2){ | |
if(isScalar(m1)){ | |
double v = m1.val(0); | |
return plus(v, m2); | |
} | |
if(isScalar(m2)){ | |
double v = m2.val(0); | |
return plus(m1, v); | |
} | |
requireEqual(m1.rows(), m2.rows(), | |
()->String.format("cannot add matrices of unequal dimensions, m1%s m2%s",m1.shapeString(),m2.shapeString())); | |
requireEqual(m1.cols(), m2.cols(), | |
()->String.format("cannot add matrices of unequal dimensions, m1%s m2%s",m1.shapeString(),m2.shapeString())); | |
Matrix result = new Matrix(m1.rows(), m1.cols()); | |
for(int r=0; r < result.rows(); r++){ | |
for(int c=0; c < result.cols(); c++){ | |
result.set(r, c, m1.val(r,c)+m2.val(r, c)); | |
} | |
} | |
return result; | |
} | |
static Matrix plus(double scalar, Mat m){ | |
Matrix result = new Matrix(m.rows(), m.cols()); | |
for(int r=0; r < result.rows(); r++) | |
for(int c=0; c < result.cols(); c++){ | |
result.set(r,c, scalar+m.val(r,c)); | |
} | |
return result; | |
} | |
static Matrix plus(Mat m, double scalar){ | |
Matrix result = new Matrix(m.rows(), m.cols()); | |
for(int r=0; r < result.rows(); r++) | |
for(int c=0; c < result.cols(); c++){ | |
result.set(r,c, m.val(r,c)+scalar); | |
} | |
return result; | |
} | |
static Matrix minus(Mat m1, Mat m2){ | |
return plus(m1, m2.negative()); | |
} | |
static Matrix minus(double scalar, Mat m){ | |
return plus(scalar, m.negative()); | |
} | |
static Matrix minus(Mat m, double scalar){ | |
return plus(m,-scalar); | |
} | |
static Matrix mult(Mat m1, Mat m2){ | |
if(isScalar(m1)){ | |
return mult(m1.val(0),m2); | |
} | |
if(isScalar(m2)){ | |
return mult(m2.val(0),m1); | |
} | |
if(!canMultiply(m1, m2)){ | |
throw new IllegalArgumentException( | |
String.format("Cannot multiply matrices, dimensions dont match. m1%s m2%s", | |
m1.shapeString(),m2.shapeString())); | |
} | |
Matrix result = new Matrix(m1.rows(),m2.cols()); | |
for(int r=0; r<result.rows(); r++){ | |
for(int c=0; c<result.cols(); c++){ | |
double v = 0; | |
for(int k=0; k<m1.cols(); k++){ | |
v+=m1.val(r, k)*m2.val(k, c); | |
} | |
result.set(r, c, v); | |
} | |
} | |
return result; | |
} | |
static Matrix mult(double scalar, Mat m){ | |
Matrix result = new Matrix(m.rows(), m.cols()); | |
for(int i = 0; i < result.numValues();i++){ | |
result.set(i, scalar*m.val(i)); | |
} | |
return result; | |
} | |
static Matrix mult(Mat m, double scalar){ | |
return mult(scalar,m); | |
} | |
static Matrix multElementWise(Mat m1, Mat m2){ | |
requireEqual(m1.rows(), m2.rows(), | |
()->String.format("cannot elem.wise mult matrices of unequal dimensions, m1%s m2%s",m1.shapeString(),m2.shapeString())); | |
requireEqual(m1.cols(), m2.cols(), | |
()->String.format("cannot elem.wise mult matrices of unequal dimensions, m1%s m2%s",m1.shapeString(),m2.shapeString())); | |
Matrix result = new Matrix(m1.rows(), m1.cols()); | |
for(int r=0; r<result.rows(); r++) | |
for(int c=0; c<result.cols(); c++){ | |
result.set(r, c, m1.val(r, c)*m2.val(r, c)); | |
} | |
return result; | |
} | |
static Mat scalar(double scalar){ | |
return new Mat(){ | |
@Override | |
public int rows() {return 1;} | |
@Override | |
public int cols() {return 1;} | |
@Override | |
public double val(int r, int c) {return scalar;} | |
@Override | |
public String toString() {return asString();} | |
}; | |
} | |
static Matrix vector(double... values){ | |
return new Matrix(values.length, 1, values); | |
} | |
static Mat allSame(int rows, int cols, double value){ | |
return new Mat(){ | |
@Override | |
public int rows() {return rows;} | |
@Override | |
public int cols() {return cols;} | |
@Override | |
public double val(int r, int c) {return value;} | |
@Override | |
public String toString() {return asString();} | |
}; | |
} | |
static Mat zeros(int rows, int cols){ | |
return allSame(rows, cols, 0); | |
} | |
static Mat ones(int rows, int cols){ | |
return allSame(rows, cols, 1); | |
} | |
static Mat eye(int rows, int cols){ | |
return new Mat(){ | |
@Override | |
public int rows() {return rows;} | |
@Override | |
public int cols() {return cols;} | |
@Override | |
public double val(int r, int c) {return r==c? 1:0;} | |
@Override | |
public String toString() {return asString();} | |
}; | |
} | |
static Mat diag(int rows, int cols, double... diagvalues){ | |
if(rows < diagvalues.length || cols < diagvalues.length){ | |
throw new IllegalArgumentException( | |
String.format("provided more values (%d) than fit on diagonal of %dx%d", diagvalues.length,rows,cols)); | |
} | |
return new Mat(){ | |
@Override | |
public int rows() {return rows;} | |
@Override | |
public int cols() {return cols;} | |
@Override | |
public double val(int r, int c) {return r==c? diagvalues[r]:0;} | |
@Override | |
public String toString() {return asString();} | |
}; | |
} | |
static Mat elementWise(Mat m, DoubleUnaryOperator fn){ | |
Matrix result = new Matrix(m.rows(), m.cols()); | |
for(int r=0; r<result.rows(); r++) | |
for(int c=0; c<result.cols(); c++){ | |
result.set(r, c, fn.applyAsDouble(m.val(r, c))); | |
} | |
return result; | |
} | |
static double sum(Mat m){ | |
double sum = 0; | |
for(int i = 0; i < m.numValues(); i++){ | |
sum += m.val(i); | |
} | |
return sum; | |
} | |
static double min(Mat m){ | |
if(m.numValues()<1){ | |
return Double.NaN; | |
} | |
double min = m.val(0); | |
for(int i=1; i < m.numValues(); i++){ | |
min = Math.min(min, m.val(i)); | |
} | |
return min; | |
} | |
static double max(Mat m){ | |
if(m.numValues()<1){ | |
return Double.NaN; | |
} | |
double max = m.val(0); | |
for(int i=1; i < m.numValues(); i++){ | |
max = Math.max(max, m.val(i)); | |
} | |
return max; | |
} | |
static Mat linspace(double left, double right, int n){ | |
if(n < 2){ | |
throw new IllegalArgumentException("n has to be greater than 1, n="+n); | |
} | |
double step = (right-left)/(n-1); | |
return new Mat(){ | |
@Override | |
public int rows() {return n;} | |
@Override | |
public int cols() {return 1;} | |
@Override | |
public double val(int idx) {return left+idx*step;} | |
@Override | |
public double val(int r, int c) {return val(r);} | |
@Override | |
public String toString() {return asString();} | |
}; | |
} | |
static Mat reshape(Mat m, int rows, int cols){ | |
if(m.numValues() != rows*cols){ | |
throw new IllegalArgumentException("cannot reshape matrix " + m.shapeString() + " to (" + rows + " x " + cols + "), unequal number of elements"); | |
} | |
return new Mat(){ | |
@Override | |
public int rows() {return rows;} | |
@Override | |
public int cols() {return cols;} | |
@Override | |
public double val(int r, int c) {return m.val(r*cols()+c);} | |
@Override | |
public String toString() {return asString();} | |
@Override | |
public int wrapCount() {return m.wrapCount()+1;} | |
}; | |
} | |
static Mat[] meshgrid(Mat x, Mat y){ | |
if(!isVector(x)){ | |
throw new IllegalArgumentException("x is not a vector, " +x.shapeString()); | |
} | |
if(!isVector(y)){ | |
throw new IllegalArgumentException("y is not a vector, " +y.shapeString()); | |
} | |
int rows = y.numValues(); | |
int cols = x.numValues(); | |
return new Mat[]{ | |
new Mat(){ | |
@Override | |
public int rows() {return rows;} | |
@Override | |
public int cols() {return cols;} | |
@Override | |
public double val(int r, int c) {return x.val(c);} | |
@Override | |
public String toString() {return asString();} | |
@Override | |
public int wrapCount() {return x.wrapCount()+1;} | |
}, | |
new Mat(){ | |
@Override | |
public int rows() {return rows;} | |
@Override | |
public int cols() {return cols;} | |
@Override | |
public double val(int r, int c) {return y.val(r);} | |
@Override | |
public String toString() {return asString();} | |
@Override | |
public int wrapCount() {return y.wrapCount()+1;} | |
}, | |
}; | |
} | |
static Mat stackCols(Mat...mats){ | |
int rows = mats[0].rows(); | |
int cols = 0; | |
int maxWrap = 0; | |
for(int i=0; i<mats.length; i++){ | |
if(rows != mats[i].rows()){ | |
throw new IllegalArgumentException("Cannot stack columns with different number of rows"); | |
} | |
cols +=mats[i].cols(); | |
maxWrap = Math.max(maxWrap, mats[i].wrapCount()); | |
} | |
final int[] colLUT = new int[cols]; | |
final int[] colOFF = new int[cols]; | |
int currentC = 0; | |
for(int i=0; i < mats.length; i++){ | |
for(int c=0; c < mats[i].cols(); c++){ | |
colLUT[currentC+c] = i; | |
colOFF[currentC+c] = currentC; | |
} | |
currentC += mats[i].cols(); | |
} | |
final int numrows = rows; | |
final int numcols = cols; | |
final int maxWrapCount = maxWrap; | |
return new Mat(){ | |
@Override | |
public int rows() {return numrows;} | |
@Override | |
public int cols() {return numcols;} | |
@Override | |
public double val(int r, int c) { | |
int matidx = colLUT[c]; | |
int offset = colOFF[c]; | |
return mats[matidx].val(r, c-offset); | |
} | |
@Override | |
public String toString() {return asString();} | |
@Override | |
public int wrapCount() {return maxWrapCount+1;} | |
}; | |
} | |
static Mat stackRows(Mat...mats){ | |
int cols = mats[0].cols(); | |
int rows = 0; | |
int maxWrap = 0; | |
for(int i=0; i<mats.length; i++){ | |
if(cols != mats[i].cols()){ | |
throw new IllegalArgumentException("Cannot stack columns with different number of rows"); | |
} | |
rows +=mats[i].rows(); | |
maxWrap = Math.max(maxWrap, mats[i].wrapCount()); | |
} | |
final int[] rowLUT = new int[rows]; | |
final int[] rowOFF = new int[rows]; | |
int currentR = 0; | |
for(int i=0; i < mats.length; i++){ | |
for(int r=0; r < mats[i].rows(); r++){ | |
rowLUT[currentR+r] = i; | |
rowOFF[currentR+r] = currentR; | |
} | |
currentR += mats[i].rows(); | |
} | |
final int numrows = rows; | |
final int numcols = cols; | |
final int maxWrapCount = maxWrap; | |
return new Mat(){ | |
@Override | |
public int rows() {return numrows;} | |
@Override | |
public int cols() {return numcols;} | |
@Override | |
public double val(int r, int c) { | |
int matidx = rowLUT[r]; | |
int offset = rowOFF[r]; | |
return mats[matidx].val(r-offset,c); | |
} | |
@Override | |
public String toString() {return asString();} | |
@Override | |
public int wrapCount() {return maxWrapCount+1;} | |
}; | |
} | |
static Mat readFile(File f){ | |
try( | |
Scanner sc = new Scanner(f); | |
){ | |
ArrayList<Matrix> rows = new ArrayList<>(); | |
while(sc.hasNextLine()){ | |
String line = sc.nextLine(); | |
if(line.isEmpty()) | |
continue; | |
String[] splits = line.split(" "); | |
double[] array = Arrays.stream(splits).mapToDouble(Double::valueOf).toArray(); | |
rows.add(new Matrix(1, array.length, array)); | |
} | |
int nrows = rows.size(); | |
int ncols = rows.isEmpty() ? 0:rows.get(0).cols(); | |
return new Mat(){ | |
@Override | |
public int rows() {return nrows;} | |
@Override | |
public int cols() {return ncols;} | |
@Override | |
public double val(int r, int c) {return rows.get(r).val(c);} | |
@Override | |
public String toString() {return asString();} | |
}; | |
} catch (IOException e) { | |
throw new RuntimeException(e); | |
} | |
} | |
public static void main(String[] args) { | |
// doNN(); | |
Img loaded = ImageLoader.loadImgFromURL("http://fixthephoto.com/UserFiles/color-correction-before-after.jpg"); | |
Img img1 = loaded.copyArea(0, 0, loaded.getWidth()/2, loaded.getHeight()-30, null, 0, 0); | |
Img img2 = loaded.copyArea(loaded.getWidth()/2, 0, loaded.getWidth()/2, loaded.getHeight()-30, null, 0, 0); | |
UnaryOperator<Mat> transform = doNNImg(img1, img2); | |
ImageFrame.display(img1).setTitle("orig"); | |
ImageFrame.display(img2).setTitle("ytrue"); | |
PixelConverter<Pixel, Matrix> conv = new PixelConverter<Pixel, Matrix>() { | |
@Override | |
public Matrix allocateElement() { | |
return new Matrix(3, 1); | |
} | |
@Override | |
public void convertPixelToElement(Pixel px, Matrix element) { | |
element.set(0, px.r_asDouble()); | |
element.set(1, px.g_asDouble()); | |
element.set(2, px.b_asDouble()); | |
} | |
@Override | |
public void convertElementToPixel(Matrix element, Pixel px) { | |
px.setRGB_fromDouble(element.val(0), element.val(1), element.val(2)); | |
} | |
}; | |
Img img3 = img1.copy(); | |
img3.forEach(conv,true, (vec)->{Mat r = transform.apply(stackRows(scalar(1),vec)); vec.set(0, 0, r);}); | |
ImageFrame.display(img3).setTitle("reconst"); | |
} | |
/**
 * Experiment: trains a network (input -> 100 sigmoid units -> 1 linear output) with hinge loss
 * {@code max(0, 1 - pred*y)} by full-batch gradient descent (step 0.05) on
 * "data2Class_adjusted.txt". Columns 0..2 are used as inputs and column 3 as the label
 * (presumably +/-1 — TODO confirm against the data file). Prints the iteration count and mean
 * loss every 100 iterations, then the weights in MATLAB syntax. Runs until the mean loss is at
 * most 0.01; there is no iteration cap.
 */
static void doNN() {
    final Mat data = readFile(new File("data2Class_adjusted.txt"));
    final Mat inputs = data.sliceCols(0, 2);
    final Mat labels = data.sliceCols(3, 3);
    final int numSamples = inputs.rows();
    final int[] layerSizes = {inputs.cols(), 100, 1};
    final int numLayers = layerSizes.length-1;

    final Mat[] weights = new Mat[numLayers];
    for(int i = 0; i < numLayers; i++){
        // uniform random weights in [-1, 1)
        weights[i] = elementWise(zeros(layerSizes[i+1], layerSizes[i]), (v)->Math.random()*2-1);
    }
    Mat[] gradients = new Mat[numLayers];
    final DoubleUnaryOperator sigmoid = (t)->1.0/(1.0+Math.exp(-t));
    final DoubleBinaryOperator hinge = (p, t)->Math.max(0, 1.0-p*t);
    final Mat learningRate = scalar(0.05);

    double meanLoss = 10;
    int iteration = 0;
    while(meanLoss > 0.01){
        for(int i = 0; i < numLayers; i++){
            gradients[i] = zeros(layerSizes[i+1], layerSizes[i]);
        }
        final Matrix losses = new Matrix(numSamples, 1);
        for(int s = 0; s < numSamples; s++){
            final Mat[] layerInputs = new Mat[numLayers];
            final Mat sample = inputs.sliceRows(s, s).transpose();
            final Mat out = fwd(sample, weights, layerInputs, sigmoid);
            final double label = labels.val(s);
            final double pred = out.val(0);
            losses.set(s, hinge.applyAsDouble(pred, label));
            // d(hinge)/d(pred) is -label while the margin is violated, 0 otherwise
            final double violated = (1-pred*label > 0) ? 1 : 0;
            gradients = bwd(scalar(-label*violated), weights, layerInputs, gradients);
        }
        for(int i = 0; i < numLayers; i++){
            weights[i] = minus(weights[i], mult(learningRate, gradients[i]));
        }
        meanLoss = sum(losses)/numSamples;
        final boolean report = iteration % 100 == 0;
        iteration++;
        if(report){
            System.out.println(iteration + " " + meanLoss);
        }
    }
    for(int i = 0; i < numLayers; i++){
        System.out.println("layers{"+(i+1)+"} = "+weights[i].matlabString()+";");
    }
}
static UnaryOperator<Mat> doNNImg(ImgBase<? extends PixelBase> initial, ImgBase<? extends PixelBase> target) { | |
requireEqual(initial.getDimension(), target.getDimension(), ()->"images are not of same dims"); | |
Matrix data = new Matrix(initial.numValues(), 7); | |
for(PixelBase p: initial){ | |
data.set(p.getIndex(), 0, 1); | |
data.set(p.getIndex(), 1, p.r_asDouble()); | |
data.set(p.getIndex(), 2, p.g_asDouble()); | |
data.set(p.getIndex(), 3, p.b_asDouble()); | |
} | |
for(PixelBase p: target){ | |
data.set(p.getIndex(), 4, p.r_asDouble()); | |
data.set(p.getIndex(), 5, p.g_asDouble()); | |
data.set(p.getIndex(), 6, p.b_asDouble()); | |
} | |
Mat X = data.sliceCols(0, 3); | |
Mat Y = data.sliceCols(4, 6); | |
int n = X.rows(); | |
int m = X.cols(); | |
int mY = Y.cols(); | |
int[] l = new int[]{m,8,8,mY}; | |
int numlayers = l.length-1; | |
Mat[] layers = new Mat[numlayers]; | |
for(int i = 0; i < numlayers; i++){ | |
Mat layer = zeros(l[i+1], l[i]); | |
layers[i] = elementWise(layer, (v)->Math.random()*2-1); | |
} | |
Mat[] gradients = new Mat[numlayers]; | |
DoubleUnaryOperator sigmoid = (x)->1.0/(1.0+Math.exp(-x)); | |
BiFunction<Mat,Mat,Double> lossfn = (pred,ytrue)->minus(pred, ytrue).l2norm_squared(); | |
BinaryOperator<Mat> deltafn = (pred,ytrue)->mult(2,minus(pred,ytrue)).transpose(); | |
double totalLoss = 10; | |
double lossdiff = 10; | |
int iter =0; | |
while(totalLoss > 0.14){ | |
for(int i = 0; i < numlayers; i++){ | |
gradients[i] = zeros(l[i+1], l[i]); | |
} | |
Matrix losses = new Matrix(n, 1); | |
for(int i=0; i<n; i++){ | |
Mat[] xes = new Mat[numlayers]; | |
Mat x = X.sliceRows(i,i).transpose(); | |
Mat f = fwd(x, layers, xes, sigmoid); | |
Mat y = Y.sliceRows(i,i).transpose(); | |
double loss = lossfn.apply(f, y); | |
losses.set(i, loss); | |
Mat delta = deltafn.apply(f, y); //scalar(-Y.val(i)*(1-f.val(0)*Y.val(i) > 0 ? 1:0)); | |
gradients = bwd(delta, layers, xes, gradients); | |
} | |
double totallossnow = sum(losses)/n; | |
lossdiff = Math.abs(totallossnow-totalLoss); | |
totalLoss = totallossnow; | |
Mat alpha = scalar(0.05*(1.0/n)); | |
for(int i=0; i<numlayers; i++){ | |
layers[i] = minus(layers[i], mult(alpha,gradients[i])); | |
} | |
// if((iter) % 100 == 0) | |
System.out.println(iter + " " + totalLoss); | |
iter++; | |
} | |
for(int i = 0; i < numlayers; i++){ | |
Mat layer = layers[i]; | |
System.out.println("layers{"+(i+1)+"} = "+layer.matlabString()+";"); | |
} | |
for(int i = 0; i < numlayers; i++){ | |
Mat layer = layers[i]; | |
System.out.println(layer); | |
} | |
return (x)->fwd(x, layers, null, sigmoid); | |
} | |
static Mat fwd(Mat x, Mat[] layers, Mat[] xes, DoubleUnaryOperator activationFn){ | |
int numlayers = layers.length; | |
Mat z=null; | |
for(int i = 0; i < numlayers; i++){ | |
if(xes != null) | |
xes[i] = x; | |
z = mult(layers[i],x); | |
if(i < numlayers-1) | |
x = elementWise(z, activationFn); | |
} | |
return z; | |
} | |
static Mat[] bwd(Mat delta, Mat[] layers, Mat[] xes, Mat[] gradients) { | |
int numlayers = layers.length; | |
for(int i = numlayers-1; i >= 0; i--){ | |
Mat x = xes[i]; | |
Mat xmx = multElementWise(x, minus(1, x)).transpose(); | |
Mat gradient = mult(delta.transpose(),x.transpose()); | |
gradients[i] = plus(gradients[i], gradient); | |
if(i > 0) | |
delta = multElementWise(mult(delta,layers[i]),xmx); | |
} | |
return gradients; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment