Skip to content

Instantly share code, notes, and snippets.

@esshka
Created September 18, 2024 10:10
Show Gist options
  • Save esshka/1ba508d7a22042342c742cc57b0d5e2b to your computer and use it in GitHub Desktop.
Save esshka/1ba508d7a22042342c742cc57b0d5e2b to your computer and use it in GitHub Desktop.
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
/**
 * Evaluates a network whose structure and weights are encoded in a flat {@code double[]}.
 *
 * <p>Layout of the encoding array {@code r} given to the constructor, as read by {@link #a}:
 * <ul>
 *   <li>{@code r[0]} – number of inputs; {@link #a} requires an input of exactly this length.</li>
 *   <li>{@code r[1]} – number of outputs; {@link #a} returns the last that-many entries of the
 *       value array.</li>
 *   <li>From index 2, one record per node, processed in order:
 *       {@code target, bias, functionIndex, selfWeight, selfFactorIndex}, followed by zero or
 *       more {@code source, weight, factorIndex} triples, terminated by {@code -2}.</li>
 * </ul>
 * A factor index of {@code -1} means a factor of {@code 1.0}. Any other factor index selects
 * the current entry of the value array at that index.
 *
 * <p>For each record the node's sum is
 * {@code selfFactor * selfWeight * sums[target] + bias + Σ values[source] * weight * factor}.
 * The sum is stored in {@code sums[target]}. The result of
 * {@code Z.functions[functionIndex]} applied to it is stored in {@code values[target]}.
 *
 * <p>Not thread-safe: {@link #a} mutates the caller-supplied state arrays and a shared output
 * buffer without synchronization.
 */
public class UnsafeNN implements NN {
    /** Marks the end of a node record's list of incoming connections. */
    private static final double END_OF_CONNECTIONS = -2;
    /** Factor index meaning "no factor", i.e. multiply by {@code 1.0}. */
    private static final int NO_FACTOR = -1;
    /** Offset of the first node record; slots 0 and 1 hold the input and output counts. */
    private static final int FIRST_RECORD = 2;

    /** Caller-owned value array (inputs, then node results); shared by reference. */
    private final double[] values;
    /** Caller-owned array of each node's latest sum; shared by reference. */
    private final double[] sums;
    /** Private snapshot of the encoding array. */
    private final double[] genome;
    /** Required input length, taken from {@code r[0]}. */
    private final int inputCount;
    /** Output buffer reused (and returned) by every call to {@link #a}. */
    private final double[] output;

    /**
     * Creates an evaluator over caller-owned state arrays.
     *
     * @param p value array. Inputs are written to its first {@code r[0]} slots and node results
     *          to their target indices. Stored by reference, not copied.
     * @param q sums array. {@code q[target]} is read as the node's previous sum and overwritten
     *          on every call. Stored by reference, not copied.
     * @param r encoded network (see class documentation). It is copied, so later changes to
     *          {@code r} do not affect this instance.
     * @throws ArrayIndexOutOfBoundsException if {@code r} has fewer than two elements
     * @throws NegativeArraySizeException if {@code r[1]} is negative
     */
    public UnsafeNN(double[] p, double[] q, double[] r) {
        this.values = p;
        this.sums = q;
        // A plain heap copy avoids sizing an off-heap buffer with the int expression
        // r.length * Double.BYTES, which silently wraps for very large encodings.
        this.genome = r.clone();
        this.inputCount = (int) r[0];
        this.output = new double[(int) r[1]];
    }

    /**
     * Runs one evaluation pass.
     *
     * <p>First copies {@code b} into the leading slots of the value array. Then evaluates every
     * node record in encoding order, so later records see values written earlier in the same
     * pass. Finally copies the last {@code r[1]} entries of the value array into the output
     * buffer.
     *
     * @param b input values; its length must equal {@code r[0]}
     * @return the output values. The same array instance is returned by every call and is
     *         overwritten by the next one, so copy it to keep a result.
     * @throws IllegalArgumentException if {@code b.length != r[0]}
     * @throws IndexOutOfBoundsException if the encoding references an index outside the state
     *         arrays, or a node record lacks its {@code -2} terminator
     */
    @Override
    public double[] a(double[] b) {
        final int expected = inputCount;
        if (b.length != expected) {
            throw new IllegalArgumentException("Oops. Expected: " + expected + ", but got: " + b.length);
        }
        System.arraycopy(b, 0, values, 0, expected);

        final double[] code = genome;
        int pos = FIRST_RECORD;
        while (pos < code.length) {
            final int target = (int) code[pos++];
            final double bias = code[pos++];
            final int functionIndex = (int) code[pos++];
            final double selfWeight = code[pos++];
            final double selfFactor = factor((int) code[pos++]);
            // Recurrent term on the node's previous sum, plus the bias.
            double sum = selfFactor * selfWeight * sums[target] + bias;

            while (code[pos] != END_OF_CONNECTIONS) {
                final int source = (int) code[pos++];
                final double weight = code[pos++];
                final double scale = factor((int) code[pos++]);
                sum += values[source] * weight * scale;
            }
            pos++; // step past the END_OF_CONNECTIONS terminator

            sums[target] = sum;
            values[target] = Z.functions[functionIndex].applyAsDouble(sum);
        }

        // The outputs are the trailing entries of the value array.
        System.arraycopy(values, values.length - output.length, output, 0, output.length);
        return output;
    }

    /**
     * Resolves a factor index to its multiplier.
     *
     * @param index {@link #NO_FACTOR} or an index into the value array
     * @return {@code 1.0} for {@link #NO_FACTOR}, otherwise the current value at {@code index}
     */
    private double factor(int index) {
        return index == NO_FACTOR ? 1.0 : values[index];
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment