Skip to content

Instantly share code, notes, and snippets.

@hageldave
Last active July 14, 2017 13:48
Show Gist options
  • Save hageldave/5f378bb01da945f8b5fe21ca421c3354 to your computer and use it in GitHub Desktop.
Save hageldave/5f378bb01da945f8b5fe21ca421c3354 to your computer and use it in GitHub Desktop.
My first implementation of a Radon transform and filtered back projection.
package hageldave.imagingkit.core;
import java.awt.BasicStroke;
import java.awt.Color;
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.DoubleFunction;
import hageldave.imagingkit.core.Img;
import hageldave.imagingkit.core.Pixel;
import hageldave.imagingkit.core.io.ImageLoader;
import hageldave.imagingkit.core.scientific.DImg;
import hageldave.imagingkit.core.util.ImageFrame;
public class testmain {
	/** Maximum number of sinogram rows (projection angles) that {@link #fbp} accumulates per pixel. */
	static int NUM_BACK_PROJECTIONS = 256;
	/**
	 * If true, image samples with alpha == 0 are excluded from ray averages, and a ray without any
	 * non-transparent sample produces pixel value 0.
	 */
	static boolean ALLOW_TRANSPARENT_RADON = false;
	/**
	 * sqrt(2), the half-diagonal of the normalized [-1,1]x[-1,1] image square; used as the extent of
	 * ray offsets and ray parameters. The exact value is used so rays reach the image corners.
	 */
	static final double sqrt2 = Math.sqrt(2);

	/**
	 * Scratch entry point: loads a phantom image from a hard-coded local file URL, computes a
	 * fan-beam sinogram of it and displays the result. The commented-out code keeps earlier
	 * experiments (parallel-beam sinograms, FBP reconstructions, exporting images).
	 */
	public static void main(String[] args) throws InterruptedException {
		Img img = ImageLoader.loadImgFromURL("file:///home/haegeldd/git/Tex/seminar_cor/img/phantom_centerdot.png");
//		Img img = ImageLoader.loadImgFromURL("https://upload.wikimedia.org/wikipedia/commons/e/e5/Shepp_logan.png");
//		Img radonTransform = radonTransform(img, 256, false, 0.0/256);
//		Img fbp = fbp(radonTransform, 256, false, true, false);
//		ImageFrame.display(fbp);
//		ImageFrame.display(img);
		Img rad = radonTransform_fan(img, 512, 5,8);
		ImageFrame.display(rad);
//		int i = 0;
//		Img rad = radonTransform(img, 256, true, 2*sqrt2*i/256);
//		Img fbp = fbp(rad, 256, true, true, false);
//		rad.forEach(px->px.setRGB(255-px.r(), 255-px.g(), 255-px.b()));
//		fbp.forEach(px->px.setRGB(255-px.r(), 255-px.g(), 255-px.b()));
//		rad.paint(g->
//		{
//			g.setColor(Color.red);
//			g.setStroke(new BasicStroke(1, BasicStroke.CAP_SQUARE, BasicStroke.JOIN_MITER, 1, new float[]{4,4}, 0));
//			g.drawLine(rad.getWidth()/2, 0, rad.getWidth()/2, rad.getHeight());
//		});
//		ImageFrame.display(fbp);
//		ImageSaver.saveImage(rad.getRemoteBufferedImage(), String.format("rad_shift%02d.png",i));
//		ImageSaver.saveImage(fbp.getRemoteBufferedImage(), String.format("fbp_shift%02d.png",i));
//		for(int i = 1; i < 64; i*=2){
//			Img rad = radonTransform(img, 256, true, 1.0*i/256);
//			Img fbp = fbp(rad, 256, true, true, false);
//			rad.forEach(px->px.setRGB(255-px.r(), 255-px.g(), 255-px.b()));
//			fbp.forEach(px->px.setRGB(255-px.r(), 255-px.g(), 255-px.b()));
//			rad.paint(g->
//			{
//				g.setColor(Color.red);
//				g.setStroke(new BasicStroke(1, BasicStroke.CAP_SQUARE, BasicStroke.JOIN_MITER, 1, new float[]{4,4}, 0));
//				g.drawLine(rad.getWidth()/2, 0, rad.getWidth()/2, rad.getHeight());
//			});
//			ImageSaver.saveImage(rad.getRemoteBufferedImage(), String.format("rad_shift%02d.png",i));
//			ImageSaver.saveImage(fbp.getRemoteBufferedImage(), String.format("fbp_shift%02d.png",i));
//		}
//		for(int i = 1; i < 9; i++){
//			NUM_BACK_PROJECTIONS*=2;
//			fbp = fbp(radonTransform, 256, false, true, true);
//			fbp.forEach(px->px.setRGB(255-px.r(), 255-px.g(), 255-px.b()));
//			ImageSaver.saveImage(fbp.getRemoteBufferedImage(), String.format("fbp_%02d.png", NUM_BACK_PROJECTIONS));
//		}
	}

	/** Half-turn (0..PI) parallel-beam sinogram with {@code res} offsets and {@code res} angles. */
	static Img radonTransform(Img img, int res){
		return radonTransform(img, res, false, 0);
	}

	/** Parallel-beam sinogram; {@code full360} doubles the number of angle rows to cover 0..2*PI. */
	static Img radonTransform(Img img, int res, boolean full360){
		return radonTransform(img, res, full360, 0);
	}

	/**
	 * Computes a parallel-beam sinogram of {@code img}.
	 * Row y corresponds to the projection angle y*PI/res. Column x corresponds to the signed ray
	 * offset (x_normalized - 0.5)*2*sqrt2 + centerShift, in normalized [-1,1] image coordinates.
	 * Each pixel holds the mean color along its ray.
	 *
	 * @param img source image
	 * @param res sinogram width, and the number of angle rows per half turn
	 * @param full360 if true the sinogram has 2*res rows (angles 0..2*PI), otherwise res rows
	 * @param centerShift constant added to every ray offset
	 * @return the sinogram; progress fractions are printed to stdout while computing
	 */
	static Img radonTransform(Img img, int res, boolean full360, double centerShift){
		double imgDiagonal = diagonalSampleCount(img);
		AtomicInteger count = new AtomicInteger();
		Img radonImg = new Img(res, full360 ? res*2:res);
		radonImg.forEach(true, px-> {
			double a = px.getY()*Math.PI/res;
			double r = (px.getXnormalized()-0.5)*(2*sqrt2) + centerShift;
			projectRay(img, imgDiagonal, a, r, px);
			reportProgress(count, res, radonImg);
		});
		return radonImg;
	}

	/**
	 * Computes a fan-beam-style sinogram of {@code img} with 2*res angle rows (0..2*PI).
	 * A detector column x sits at position (x_normalized - 0.5)*detectorW. Its viewing angle
	 * va = atan(position/focalLen) is added to the row angle y*PI/res, and the ray offset
	 * becomes sin(va)*focalLen/2. Each pixel holds the mean color along its ray.
	 *
	 * @param img source image
	 * @param res sinogram width and number of angle rows per half turn
	 * @param detectorW detector width, in the same units as focalLen
	 * @param focalLen focal length used to derive each column's viewing angle
	 * @return the sinogram; progress fractions are printed to stdout while computing
	 */
	static Img radonTransform_fan(Img img, int res, double detectorW, double focalLen){
		double imgDiagonal = diagonalSampleCount(img);
		AtomicInteger count = new AtomicInteger();
		Img radonImg = new Img(res, res*2);
		radonImg.forEach(true, px-> {
			double detectorPos = (px.getXnormalized()-0.5)*detectorW;
			double va = Math.atan(detectorPos/focalLen);
			double a = px.getY()*Math.PI/res + va;
			double r = Math.sin(va)*(focalLen/2);
			projectRay(img, imgDiagonal, a, r, px);
			reportProgress(count, res, radonImg);
		});
		return radonImg;
	}

	/** Number of samples taken per ray: the image diagonal in pixels, plus one. */
	private static double diagonalSampleCount(Img img){
		double w = img.getWidth();
		double h = img.getHeight();
		return Math.sqrt(h*h + w*w) + 1;
	}

	/**
	 * Averages the colors of {@code img} along one ray and writes the mean into {@code px}.
	 * The ray passes through (r*cos(a), r*sin(a)) in normalized [-1,1] coordinates with direction
	 * (sin(a), -cos(a)). It is sampled at ceil(imgDiagonal) equidistant parameters
	 * t in [-sqrt2, sqrt2), and samples falling outside the image are ignored.
	 */
	private static void projectRay(Img img, double imgDiagonal, double a, double r, Pixel px){
		double sin = Math.sin(a);
		double cos = Math.cos(a);
		double xC = r*cos;
		double yC = r*sin;
		double sumR = 0;
		double sumG = 0;
		double sumB = 0;
		int numElements = 0;
		for(int i = 0; i < imgDiagonal; i++){
			double t = (i*1.0/imgDiagonal - 0.5)*(2*sqrt2);
			// map from [-1,1] to the [0,1] coordinates passed to interpolateARGB
			double x = (xC + t*sin + 1)/2;
			double y = (yC - t*cos + 1)/2;
			if(!(inUnitRange(x) && inUnitRange(y)))
				continue;
			int val = img.interpolateARGB((float) x, (float) y);
			if(Pixel.a(val) > 0 || !ALLOW_TRANSPARENT_RADON){
				sumR += Pixel.r_normalized(val);
				sumG += Pixel.g_normalized(val);
				sumB += Pixel.b_normalized(val);
				numElements++;
			}
		}
		if(numElements == 0 && ALLOW_TRANSPARENT_RADON){
			px.setValue(0);
		} else {
			int divisor = Math.max(1, numElements);
			px.setRGB_fromNormalized((float)(sumR/divisor), (float)(sumG/divisor), (float)(sumB/divisor));
		}
	}

	/** Counts one finished sinogram pixel and prints the completed fraction after every res pixels. */
	private static void reportProgress(AtomicInteger count, int res, Img radonImg){
		int prog = count.incrementAndGet();
		if(prog % res == 0){
			System.out.println((prog*1.0/radonImg.numValues()));
		}
	}

	/** Filtered back projection with filtering enabled and angles visited in natural order. */
	static Img fbp(Img sinogram, int res, boolean full360){
		return fbp(sinogram, res, full360, true, false);
	}

	/**
	 * Reconstructs a res x res image from a sinogram by (optionally filtered) back projection.
	 * Each output pixel averages the sinogram values along its sinusoid. Sinogram samples with
	 * alpha 0 are skipped, and pixels with no contributing sample are set to 0. At most
	 * {@link #NUM_BACK_PROJECTIONS} angle rows are visited.
	 *
	 * @param sinogram sinogram whose rows span PI (or 2*PI if full360); it is copied, not modified
	 * @param res width and height of the reconstruction
	 * @param full360 whether the sinogram rows cover 2*PI instead of PI
	 * @param filter whether to apply {@link #filter(DImg)} before back projecting
	 * @param bitReversedOrder visit angle rows in bit-reversed order, so that a limited number
	 *        of projections is spread over the whole angle range
	 * @return the reconstruction, clamped to [0,1] per channel
	 */
	static Img fbp(Img sinogram, int res, boolean full360, boolean filter, final boolean bitReversedOrder){
		Img radontransform = sinogram.copy();
		double numPi = full360 ? 2:1;
		DImg radonFilter = new DImg(radontransform, true);
		if(filter)
			filter(radonFilter);
		System.out.println("done filtering, now back projecting");
		final int n = radonFilter.getHeight();
		final int w = radonFilter.getWidth();
		final double helperA = Math.PI*numPi/n;
		// Bit-reversed traversal must enumerate each index of the next power of two >= n exactly
		// once; going up to 2n (for power-of-two n) would revisit every angle.
		final int numSlots = bitReversedOrder ? nextPowerOfTwo(n) : n;
		Img img = new Img(res, res);
		img.forEach(true, px->{
			double x = (px.getXnormalized()-0.5)*2;
			double y = (px.getYnormalized()-0.5)*2;
			double sumR = 0;
			double sumG = 0;
			double sumB = 0;
			int numElements = 0;
			for(int i = 0; i < numSlots && i < NUM_BACK_PROJECTIONS; i++){
				int iA = bitReversedOrder ? bitreveresedIndex(i, n) : i;
				if(iA >= n)
					continue;
				double a = iA*helperA;
				// NOTE(review): radonTransform maps sinogram columns to offsets in [-sqrt2, sqrt2),
				// whereas this lookup maps offsets in [-1, 1] onto the column range, so a
				// reconstruction of a radonTransform sinogram comes out scaled by 1/sqrt2.
				// Confirm whether this is intended.
				double r = (x*Math.cos(a)+y*Math.sin(a) +1)/2;
				int iR = (int) (r*(w-1));
				if(iR >= 0 && iR < w && radonFilter.getValueA(iR, iA) > 0){
					sumR += radonFilter.getValueR(iR, iA);
					sumG += radonFilter.getValueG(iR, iA);
					sumB += radonFilter.getValueB(iR, iA);
					numElements++;
				}
			}
			if(numElements == 0){
				px.setValue(0);
			} else {
				px.setRGB_fromNormalized(
						(float)clampD(0, sumR/numElements, 1),
						(float)clampD(0, sumG/numElements, 1),
						(float)clampD(0, sumB/numElements, 1));
			}
		});
		return img;
	}

	/** Returns the smallest power of two that is >= n, or n itself when n <= 1. */
	private static int nextPowerOfTwo(int n){
		return n <= 1 ? n : Integer.highestOneBit(n-1) << 1;
	}

	/**
	 * Reverses the lowest ceil(log2(n)) bits of i. For n <= 1 there are no bits to reverse
	 * and i is returned unchanged.
	 */
	static int bitreveresedIndex(int i, int n){
		if(n <= 1)
			return i;
		// a shift by numberOfLeadingZeros(n-1) keeps exactly the bits needed to index [0, n)
		return Integer.reverse(i) >>> Integer.numberOfLeadingZeros(n-1);
	}

	/**
	 * Filters every row of the sinogram in the frequency domain, separately per RGB channel.
	 * Each row goes through a forward FFT and fftshift, is multiplied by a weight, is shifted
	 * back and then inverse transformed. The real part of the unscaled inverse transform
	 * replaces the channel values. The weight is |x|*(1-|x|) with x spanning [-2, 2) across
	 * the row: it is 0 at DC, peaks at |x| = 0.5 and is negative for |x| > 1.
	 */
	static void filter(DImg sinogram){
		final int width = sinogram.getWidth();
		final double[] weights = new double[width];
		for(int i = 0; i < width; i++){
			double x = ((i*1.0/width)-0.5)*4;
			weights[i] = Math.abs(x)*(1-Math.abs(x));
		}
		// fftshift moves DC from index 0 to width/2; shifting by the remaining width - width/2
		// undoes it exactly, also for odd widths
		final int forwardShift = width/2;
		final int backwardShift = width - forwardShift;
		for(int c = 0; c < 3; c++){
			final int channel = c;
			for(int row = 0; row < sinogram.getHeight(); row++){
				final double[] real = alloc(width);
				final double[] imag = alloc(width, 0);
				switch(channel){
				case 0: sinogram.forEach(0, row, width, 1, px->real[px.getX()]=px.r());
					break;
				case 1: sinogram.forEach(0, row, width, 1, px->real[px.getX()]=px.g());
					break;
				default: sinogram.forEach(0, row, width, 1, px->real[px.getX()]=px.b());
				}
				FFT.transform(real, imag);
				shift(real, forwardShift);
				shift(imag, forwardShift);
				for(int i = 0; i < width; i++){
					real[i] *= weights[i];
					imag[i] *= weights[i];
				}
				shift(real, backwardShift);
				shift(imag, backwardShift);
				// unscaled inverse: results are width times larger than a normalized IDFT
				FFT.inverseTransform(real, imag);
				sinogram.forEach(0, row, width, 1, px->{
					px.setValue(channel, real[px.getX()]);
				});
				free(real);
				free(imag);
			}
		}
	}

	/**
	 * Discrete Fourier transform routines that work in place on separate real/imaginary arrays.
	 * The forward transform uses the kernel exp(-2*pi*i*k*n/N). Only {@link #convolve} applies
	 * 1/N scaling.
	 */
	static class FFT {
		/**
		 * Computes the DFT of the given complex vector in place. Any length is accepted:
		 * powers of two use radix-2, other lengths use Bluestein's algorithm.
		 * @throws IllegalArgumentException if the arrays differ in length
		 */
		public static void transform(double[] real, double[] imag) {
			if (real.length != imag.length)
				throw new IllegalArgumentException("Mismatched lengths");
			int n = real.length;
			if (n == 0)
				return;
			else if ((n & (n - 1)) == 0) // Is power of 2
				transformRadix2(real, imag);
			else // More complicated algorithm for arbitrary sizes
				transformBluestein(real, imag);
		}

		/**
		 * Computes the inverse DFT in place by transforming with real and imaginary parts swapped.
		 * No 1/N scaling is applied, so applying it after {@link #transform} yields N times the input.
		 */
		public static void inverseTransform(double[] real, double[] imag) {
			transform(imag, real);
		}

		/**
		 * Computes the DFT in place with the iterative Cooley-Tukey radix-2 algorithm.
		 * @throws IllegalArgumentException if the lengths differ or are not a power of 2
		 */
		public static void transformRadix2(double[] real, double[] imag) {
			if (real.length != imag.length)
				throw new IllegalArgumentException("Mismatched lengths");
			transformRadix2(real, imag, 0, real.length);
		}

		/**
		 * Radix-2 DFT of the sub-vector [offset, offset+numElements), computed in place.
		 * @throws IllegalArgumentException if the range is negative, exceeds either array,
		 *         or numElements is not a power of 2
		 */
		public static void transformRadix2(double[] real, double[] imag, int offset, int numElements) {
			if (offset < 0 || numElements < 0)
				throw new IllegalArgumentException("negative offset or length");
			int maxIndex = offset+numElements;
			if (real.length < maxIndex || imag.length < maxIndex)
				throw new IllegalArgumentException("arrays are too short");
			int n = numElements;
			int levels = 31 - Integer.numberOfLeadingZeros(n); // Equal to floor(log2(n))
			if (1 << levels != n)
				throw new IllegalArgumentException("Length is not a power of 2");
			double[] cosTable = alloc(n / 2);
			double[] sinTable = alloc(n / 2);
			for (int i = 0; i < n / 2; i++) {
				cosTable[i] = Math.cos(2 * Math.PI * i / n);
				sinTable[i] = Math.sin(2 * Math.PI * i / n);
			}
			// Bit-reversed addressing permutation
			for (int i = 0; i < n; i++) {
				int j = Integer.reverse(i) >>> (32 - levels);
				if (j > i) {
					double temp = real[offset+i];
					real[offset+i] = real[offset+j];
					real[offset+j] = temp;
					temp = imag[offset+i];
					imag[offset+i] = imag[offset+j];
					imag[offset+j] = temp;
				}
			}
			// Cooley-Tukey decimation-in-time radix-2 FFT
			for (int size = 2; size <= n; size *= 2) {
				int halfsize = size / 2;
				int tablestep = n / size;
				for (int i = 0; i < n; i += size) {
					for (int j = i, k = 0; j < i + halfsize; j++, k += tablestep) {
						double tpre = real[offset+j+halfsize] * cosTable[k] + imag[offset+j+halfsize] * sinTable[k];
						double tpim = -real[offset+j+halfsize] * sinTable[k] + imag[offset+j+halfsize] * cosTable[k];
						real[offset+j + halfsize] = real[offset+j] - tpre;
						imag[offset+j + halfsize] = imag[offset+j] - tpim;
						real[offset+j] += tpre;
						imag[offset+j] += tpim;
					}
				}
				if (size == n) // Prevent overflow in 'size *= 2'
					break;
			}
			free(sinTable);
			free(cosTable);
		}

		/**
		 * Computes the DFT of an arbitrary-length vector in place using Bluestein's chirp
		 * z-transform, which reduces it to a power-of-2 circular convolution.
		 * @throws IllegalArgumentException if the lengths differ or n >= 2^29
		 */
		public static void transformBluestein(double[] real, double[] imag) {
			if (real.length != imag.length)
				throw new IllegalArgumentException("Mismatched lengths");
			int n = real.length;
			if (n >= 0x20000000)
				throw new IllegalArgumentException("Array too large");
			// Find a power-of-2 convolution length m such that m >= n * 2 + 1
			int m = Integer.highestOneBit(n * 2 + 1) << 1;
			// Trigonometric (chirp) tables
			double[] cosTable = alloc(n);
			double[] sinTable = alloc(n);
			for (int i = 0; i < n; i++) {
				int j = (int)((long)i * i % (n * 2)); // This is more accurate than j = i * i
				cosTable[i] = Math.cos(Math.PI * j / n);
				sinTable[i] = Math.sin(Math.PI * j / n);
			}
			// The convolution inputs must be zero outside the entries written below. Pooled
			// arrays may hold stale data, so zero-filled ones are requested explicitly.
			double[] areal = alloc(m, 0);
			double[] aimag = alloc(m, 0);
			for (int i = 0; i < n; i++) {
				areal[i] = real[i] * cosTable[i] + imag[i] * sinTable[i];
				aimag[i] = -real[i] * sinTable[i] + imag[i] * cosTable[i];
			}
			double[] breal = alloc(m, 0);
			double[] bimag = alloc(m, 0);
			breal[0] = cosTable[0];
			bimag[0] = sinTable[0];
			for (int i = 1; i < n; i++) {
				breal[i] = breal[m - i] = cosTable[i];
				bimag[i] = bimag[m - i] = sinTable[i];
			}
			// Convolution (fully overwrites creal/cimag)
			double[] creal = alloc(m);
			double[] cimag = alloc(m);
			convolve(areal, aimag, breal, bimag, creal, cimag);
			// Postprocessing
			for (int i = 0; i < n; i++) {
				real[i] = creal[i] * cosTable[i] + cimag[i] * sinTable[i];
				imag[i] = -creal[i] * sinTable[i] + cimag[i] * cosTable[i];
			}
			free(cosTable);
			free(sinTable);
			free(areal);
			free(aimag);
			free(breal);
			free(bimag);
			free(creal);
			free(cimag);
		}

		/**
		 * Computes the circular convolution of two real vectors into {@code out}.
		 * @throws IllegalArgumentException if the three lengths differ
		 */
		public static void convolve(double[] x, double[] y, double[] out) {
			if (x.length != y.length || x.length != out.length)
				throw new IllegalArgumentException("Mismatched lengths");
			int n = x.length;
			double[] ximag = alloc(n, 0);
			double[] yimag = alloc(n, 0);
			double[] outimag = alloc(n);
			convolve(x, ximag, y, yimag, out, outimag);
			free(ximag);
			free(yimag);
			free(outimag);
		}

		/**
		 * Computes the circular convolution of two complex vectors via FFT, scaled by 1/n.
		 * The inputs are cloned and left unmodified.
		 * @throws IllegalArgumentException if any of the six lengths differ
		 */
		public static void convolve(double[] xreal, double[] ximag, double[] yreal, double[] yimag, double[] outreal, double[] outimag) {
			if (xreal.length != ximag.length || xreal.length != yreal.length || yreal.length != yimag.length || xreal.length != outreal.length || outreal.length != outimag.length)
				throw new IllegalArgumentException("Mismatched lengths");
			int n = xreal.length;
			xreal = xreal.clone();
			ximag = ximag.clone();
			yreal = yreal.clone();
			yimag = yimag.clone();
			transform(xreal, ximag);
			transform(yreal, yimag);
			for (int i = 0; i < n; i++) {
				double temp = xreal[i] * yreal[i] - ximag[i] * yimag[i];
				ximag[i] = ximag[i] * yreal[i] + xreal[i] * yimag[i];
				xreal[i] = temp;
			}
			inverseTransform(xreal, ximag);
			for (int i = 0; i < n; i++) { // Scaling (because this FFT implementation omits it)
				outreal[i] = xreal[i] / n;
				outimag[i] = ximag[i] / n;
			}
		}
	}

	/**
	 * Pool of reusable double arrays, with one queue per array length, backed by concurrent
	 * collections. At most {@code capacity} arrays are retained per length. Arrays returned by
	 * {@link #alloc(int)} may be recycled and are NOT cleared; use {@link #alloc(int, double)}
	 * when defined contents are needed.
	 */
	static final class ArrayPool {
		private static final ConcurrentHashMap<Integer, ArrayPool> pools = new ConcurrentHashMap<>();
		private static final int capacity = 8;
		private final ConcurrentLinkedQueue<double[]> stack = new ConcurrentLinkedQueue<>();

		/** Returns the pool for the given array length, creating it on first use. */
		static ArrayPool get(final int size){
			return pools.computeIfAbsent(size, (k)->new ArrayPool());
		}

		private ArrayPool() {
		}

		/** Returns an array of the given length, either recycled (contents undefined) or new. */
		public static double[] alloc(final int size) {
			// a single poll() avoids the isEmpty()/poll() race that could return null
			final double[] pooled = get(size).stack.poll();
			return pooled != null ? pooled : new double[size];
		}

		/** Returns an array of the given length with every element set to {@code fill}. */
		public static double[] alloc(final int size, final double fill) {
			double[] array = alloc(size);
			Arrays.fill(array, fill);
			return array;
		}

		/** Hands an array back for reuse; it is dropped if its pool is already full. */
		public static void free(final double[] array){
			final ArrayPool p = get(array.length);
			if(p.stack.size() < capacity){
				p.stack.add(array);
			}
		}

		/** Returns a (possibly recycled) array holding a copy of {@code array}. */
		public static double[] arrayCopy(final double[] array){
			final double[] cpy = alloc(array.length);
			System.arraycopy(array, 0, cpy, 0, array.length);
			return cpy;
		}

		/** Drops all pooled arrays of the given length and requests a garbage collection. */
		public static void clearCache(final int size){
			final ArrayPool p = pools.get(size);
			if(p != null)
				p.stack.clear();
			System.gc();
		}

		/** Drops all pooled arrays of every length and requests a garbage collection. */
		public static void clearCacheAll(){
			for(ArrayPool p: pools.values())
				p.stack.clear();
			System.gc();
		}
	}

	/**
	 * Rescales the image's luminance so that it spans [0,1] and writes it back as a gray value.
	 * An image with constant luminance becomes black.
	 */
	static void normalize(Img img){
		double min = img.stream(true).mapToDouble(px->Pixel.b_normalized(px.getLuminance())).min().getAsDouble();
		double max = img.stream(true).mapToDouble(px->Pixel.b_normalized(px.getLuminance())).max().getAsDouble();
		double range = max - min;
		img.forEach(true, px->{
			double val = Pixel.b_normalized(px.getLuminance());
			val = range == 0 ? 0 : (val-min)/range;
			px.setRGB_fromNormalized((float)val, (float)val, (float)val);
		});
	}

	/** True iff d lies in the closed interval [0,1]. */
	static boolean inUnitRange(double d){
		return d >= 0 && d <= 1;
	}

	/** Clamps val to [lo, hi]. */
	static double clampD(double lo, double val, double hi){
		return Math.max(lo, Math.min(val, hi));
	}

	/** Shorthand for {@link ArrayPool#alloc(int)} (contents undefined). */
	static double[] alloc(int size){
		return ArrayPool.alloc(size);
	}

	/** Shorthand for {@link ArrayPool#alloc(int, double)}. */
	static double[] alloc(int size, double fill){
		return ArrayPool.alloc(size, fill);
	}

	/** Shorthand for {@link ArrayPool#free(double[])}. */
	static void free(double[] arr){
		ArrayPool.free(arr);
	}

	/**
	 * Linearly maps the values of {@code array} onto [0,1].
	 * @param result destination; null allocates one, and the input is copied into it first
	 *        unless it is the same array. A constant input maps to all zeros.
	 * @return the normalized array (result)
	 */
	public static double[] normalize(double[] array, double[] result){
		if(result == null)
			result = alloc(array.length);
		if(array != result){
			System.arraycopy(array, 0, result, 0, array.length);
			array = result;
		}
		if(array.length == 0)
			return array;
		double min = min(array);
		double range = max(array) - min;
		for(int i = 0; i < array.length; i++)
			array[i] = range == 0 ? 0 : (array[i]-min)/range;
		return array;
	}

	/** Largest element; throws ArrayIndexOutOfBoundsException for an empty array. */
	public static double max(double[] array){
		double max = array[0];
		for(double d: array)
			max = Math.max(d, max);
		return max;
	}

	/** Smallest element; throws ArrayIndexOutOfBoundsException for an empty array. */
	public static double min(double[] array){
		double min = array[0];
		for(double d: array)
			min = Math.min(d, min);
		return min;
	}

	/**
	 * Multiplies every element by {@code factor}.
	 * @param result destination; null allocates one, and the input is copied into it first
	 *        unless it is the same array
	 * @return the scaled array (result)
	 */
	public static double[] scale(double[] array, double factor, double[] result){
		if(result == null)
			result = alloc(array.length);
		if(array != result){
			System.arraycopy(array, 0, result, 0, array.length);
			array = result;
		}
		for(int i = 0; i < array.length; i++)
			array[i] *= factor;
		return array;
	}

	/** Arithmetic mean; NaN for an empty array. */
	public static double mean(double[] array){
		double sum = 0;
		for(double d: array)
			sum += d;
		return sum/array.length;
	}

	/**
	 * Element-wise complex product (realA + i*imagA) * (realB + i*imagB), written to realR/imagR
	 * for the first realA.length elements. Each element's inputs are read before its outputs are
	 * written, so the result arrays may alias any input array. A null result array is replaced
	 * by a freshly allocated one, which the caller cannot see because nothing is returned.
	 */
	public static void complexMult(double[] realA, double[] imagA, double[] realB, double[] imagB, double[] realR, double[] imagR){
		if(realR == null)
			realR = alloc(realA.length);
		if(imagR == null)
			imagR = alloc(imagA.length);
		for(int i = 0; i < realA.length; i++){
			double a = realA[i];
			double b = imagA[i];
			double c = realB[i];
			double d = imagB[i];
			realR[i] = a*c - b*d;
			imagR[i] = a*d + b*c;
		}
	}

	/** Circularly shifts the whole array in place: element i moves to index (i+shift) mod length. */
	public static void shift(double[] array, int shift){
		shift(array, shift, 0, array.length, array, array);
	}

	/**
	 * Circularly shifts the {@code size} elements starting at {@code offset} by {@code shift}
	 * positions. Indices wrap modulo array.length, and negative shifts move elements toward
	 * lower indices.
	 * @param result destination; null allocates one, and the input is copied into it first
	 *        unless it is the same array
	 * @param cpy optional scratch buffer with at least array.length elements. Null, or the
	 *        array itself, makes the method use a temporary pooled copy that is released afterwards.
	 * @return the shifted array (result)
	 */
	public static double[] shift(double[] array, int shift, int offset, int size, double[] result, double[] cpy){
		if(result == null)
			result = alloc(array.length);
		if(array != result){
			System.arraycopy(array, 0, result, 0, array.length);
			array = result;
		}
		boolean tempCpy = false;
		if(cpy == null || cpy == array){
			cpy = ArrayPool.arrayCopy(array);
			tempCpy = true;
		} else {
			// the scratch buffer must hold the pre-shift data before it is read from
			System.arraycopy(array, 0, cpy, 0, array.length);
		}
		if(size > 0){
			// floorMod keeps the target index non-negative for negative shifts
			int s = Math.floorMod(shift, size);
			for(int i = 0; i < size; i++){
				array[(offset+((i+s)%size))%array.length] = cpy[(offset+i)%array.length];
			}
		}
		if(tempCpy)
			free(cpy);
		return array;
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment