Skip to content

Instantly share code, notes, and snippets.

@zboralski
Created November 13, 2025 16:15
Show Gist options
  • Select an option

  • Save zboralski/f69e69432291c367fe257a854870ba23 to your computer and use it in GitHub Desktop.

Select an option

Save zboralski/f69e69432291c367fe257a854870ba23 to your computer and use it in GitHub Desktop.
QuantumBaristaEJML
import org.ejml.simple.SimpleMatrix;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.MessageDigest;
/**
* Quantum Barista - Compact RHF solver using EJML
* Optimized pure Java port for Android
*
* Dependencies:
* implementation 'org.ejml:ejml-simple:0.43.1'
*/
public class QuantumBaristaEJML {
// STO-3G basis for hydrogen
private static final double[] H_EXPS = {3.42525091, 0.62391373, 0.16885540};
private static final double[] H_COEFS = {0.15432897, 0.53532814, 0.44463454};
// Boys F0(x) function with small-x series expansion
private static double boys0(double x) {
if (x < 1e-6) {
// F0(x) = 1 - x/3 + x^2/10 - O(x^3)
double x2 = x * x;
return 1.0 - x / 3.0 + 0.1 * x2;
}
double sqrtX = Math.sqrt(x);
return 0.5 * Math.sqrt(Math.PI / x) * erf(sqrtX);
}
// Error function approximation (Abramowitz and Stegun)
private static double erf(double x) {
// Save sign
int sign = (x >= 0) ? 1 : -1;
x = Math.abs(x);
// Constants
double a1 = 0.254829592;
double a2 = -0.284496736;
double a3 = 1.421413741;
double a4 = -1.453152027;
double a5 = 1.061405429;
double p = 0.3275911;
double t = 1.0 / (1.0 + p * x);
double y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * Math.exp(-x * x);
return sign * y;
}
// Gaussian product center
private static class GaussianProduct {
final double p;
final double[] P;
final double K;
final double AB2;
GaussianProduct(double a, double[] A, double b, double[] B) {
this.p = a + b;
this.P = new double[3];
double ab2 = 0.0;
for (int i = 0; i < 3; i++) {
P[i] = (a * A[i] + b * B[i]) / p;
double d = A[i] - B[i];
ab2 += d * d;
}
this.AB2 = ab2;
this.K = Math.exp(-a * b * ab2 / p);
}
}
// s-s overlap integral (inlined for performance - no GaussianProduct allocation)
private static double S_ss(double a, double[] A, double b, double[] B) {
double p = a + b;
double ab2 = 0.0;
for (int i = 0; i < 3; i++) {
double d = A[i] - B[i];
ab2 += d * d;
}
double K = Math.exp(-a * b * ab2 / p);
return Math.pow(Math.PI / p, 1.5) * K;
}
// s-s kinetic energy integral (inlined for performance)
private static double T_ss(double a, double[] A, double b, double[] B) {
double p = a + b;
double ab2 = 0.0;
for (int i = 0; i < 3; i++) {
double d = A[i] - B[i];
ab2 += d * d;
}
double K = Math.exp(-a * b * ab2 / p);
return K * Math.pow(Math.PI / p, 1.5) * (a * b / p) *
(3.0 - 2.0 * a * b * ab2 / p);
}
// s-s nuclear attraction integral (inlined for performance)
private static double V_ss(double a, double[] A, double b, double[] B,
double[] C, int Z) {
double p = a + b;
double ab2 = 0.0;
double[] P = new double[3];
for (int i = 0; i < 3; i++) {
double d = A[i] - B[i];
ab2 += d * d;
P[i] = (a * A[i] + b * B[i]) / p;
}
double K = Math.exp(-a * b * ab2 / p);
double PC2 = 0.0;
for (int i = 0; i < 3; i++) {
double d = P[i] - C[i];
PC2 += d * d;
}
return -Z * (2.0 * Math.PI / p) * K * boys0(p * PC2);
}
// (ss|ss) electron repulsion integral
private static double eri_ssss(double a, double[] A, double b, double[] B,
double c, double[] C, double d, double[] D) {
GaussianProduct gp1 = new GaussianProduct(a, A, b, B);
GaussianProduct gp2 = new GaussianProduct(c, C, d, D);
double PQ2 = 0.0;
for (int i = 0; i < 3; i++) {
double diff = gp1.P[i] - gp2.P[i];
PQ2 += diff * diff;
}
return 2.0 * Math.pow(Math.PI, 2.5) * gp1.K * gp2.K /
(gp1.p * gp2.p * Math.sqrt(gp1.p + gp2.p)) *
boys0(gp1.p * gp2.p * PQ2 / (gp1.p + gp2.p));
}
// Atom class (represents a nucleus)
private static class Atom {
final int Z;
final double[] R;
Atom(int Z, double[] R) {
this.Z = Z;
this.R = R.clone(); // Clone to avoid external mutation
}
}
// Shell class (represents a basis function)
private static class Shell {
final double[] R;
final double[] exps;
final double[] coefs;
Shell(double[] R) {
this.R = R.clone(); // Clone to avoid external mutation
this.exps = H_EXPS.clone();
this.coefs = H_COEFS.clone();
}
}
// Contract generic 1e integral
private static double contr_1e(Shell shA, Shell shB, IntegralFunction f) {
double out = 0.0;
for (int i = 0; i < shA.exps.length; i++) {
for (int j = 0; j < shB.exps.length; j++) {
out += shA.coefs[i] * shB.coefs[j] *
f.eval(shA.exps[i], shA.R, shB.exps[j], shB.R);
}
}
return out;
}
@FunctionalInterface
interface IntegralFunction {
double eval(double a, double[] A, double b, double[] B);
}
// Contract nuclear attraction
private static double contr_V(Shell shA, Shell shB, Atom[] atoms) {
double out = 0.0;
for (Atom atom : atoms) {
for (int i = 0; i < shA.exps.length; i++) {
for (int j = 0; j < shB.exps.length; j++) {
out += shA.coefs[i] * shB.coefs[j] *
V_ss(shA.exps[i], shA.R, shB.exps[j], shB.R,
atom.R, atom.Z);
}
}
}
return out;
}
// Contract ERI (with GaussianProduct caching for hot path optimization)
private static double contr_eri(Shell A, Shell B, Shell C, Shell D) {
double out = 0.0;
int na = A.exps.length;
int nb = B.exps.length;
int nc = C.exps.length;
int nd = D.exps.length;
// Pre-compute GaussianProducts for AB and CD pairs
GaussianProduct[][] gpAB = new GaussianProduct[na][nb];
GaussianProduct[][] gpCD = new GaussianProduct[nc][nd];
for (int i = 0; i < na; i++) {
for (int j = 0; j < nb; j++) {
gpAB[i][j] = new GaussianProduct(A.exps[i], A.R, B.exps[j], B.R);
}
}
for (int k = 0; k < nc; k++) {
for (int l = 0; l < nd; l++) {
gpCD[k][l] = new GaussianProduct(C.exps[k], C.R, D.exps[l], D.R);
}
}
// Compute ERI using cached GaussianProducts
for (int i = 0; i < na; i++) {
for (int j = 0; j < nb; j++) {
for (int k = 0; k < nc; k++) {
for (int l = 0; l < nd; l++) {
GaussianProduct gp1 = gpAB[i][j];
GaussianProduct gp2 = gpCD[k][l];
double PQ2 = 0.0;
for (int d = 0; d < 3; d++) {
double diff = gp1.P[d] - gp2.P[d];
PQ2 += diff * diff;
}
double v = 2.0 * Math.pow(Math.PI, 2.5) * gp1.K * gp2.K
/ (gp1.p * gp2.p * Math.sqrt(gp1.p + gp2.p))
* boys0(gp1.p * gp2.p * PQ2 / (gp1.p + gp2.p));
out += A.coefs[i] * B.coefs[j] * C.coefs[k] * D.coefs[l] * v;
}
}
}
}
return out;
}
/**
* Restricted Hartree-Fock for hydrogen clusters in STO-3G.
* One contracted s function per H atom, coordinates in bohr.
* atoms.length must equal number of H atoms.
*
* @param atoms Array of [x, y, z] coordinates (all H atoms, in bohr)
* @param nelec Number of electrons (must be even for RHF)
* @param bins Number of density matrix elements to output
* @return Array of density elements + gap
*/
public static double[] solve(double[][] atoms, int nelec, int bins) {
// Validate inputs
if (atoms == null || atoms.length == 0) {
throw new IllegalArgumentException("At least one atom required");
}
if ((nelec & 1) != 0) {
throw new IllegalArgumentException("RHF solver requires an even number of electrons");
}
if (nelec <= 0 || nelec > 2 * atoms.length) {
throw new IllegalArgumentException("Invalid electron count for hydrogen-only system");
}
int n = atoms.length;
// Build shells (basis functions) and atoms (nuclei)
Shell[] shells = new Shell[n];
Atom[] nuclei = new Atom[n];
for (int i = 0; i < n; i++) {
shells[i] = new Shell(atoms[i]);
nuclei[i] = new Atom(1, atoms[i]); // Z=1 for hydrogen
}
// Build S, T, V matrices (exploit symmetry)
SimpleMatrix S = new SimpleMatrix(n, n);
SimpleMatrix T = new SimpleMatrix(n, n);
SimpleMatrix V = new SimpleMatrix(n, n);
for (int i = 0; i < n; i++) {
for (int j = 0; j <= i; j++) {
double s = contr_1e(shells[i], shells[j], QuantumBaristaEJML::S_ss);
double t = contr_1e(shells[i], shells[j], QuantumBaristaEJML::T_ss);
double v = contr_V(shells[i], shells[j], nuclei);
S.set(i, j, s); S.set(j, i, s);
T.set(i, j, t); T.set(j, i, t);
V.set(i, j, v); V.set(j, i, v);
}
}
SimpleMatrix H = T.plus(V);
// Build ERI tensor (exploit 8-fold symmetry)
double[][][][] eri = new double[n][n][n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j <= i; j++) {
for (int k = 0; k < n; k++) {
for (int l = 0; l <= k; l++) {
double v = contr_eri(shells[i], shells[j], shells[k], shells[l]);
// 8-fold permutational symmetry
eri[i][j][k][l] = v;
eri[j][i][k][l] = v;
eri[i][j][l][k] = v;
eri[j][i][l][k] = v;
eri[k][l][i][j] = v;
eri[l][k][i][j] = v;
eri[k][l][j][i] = v;
eri[l][k][j][i] = v;
}
}
}
}
// Orthogonalizer X = S^(-1/2)
SimpleMatrix X = buildOrthogonalizer(S);
// SCF iterations with core Hamiltonian initial guess
int ndocc = nelec / 2;
SimpleMatrix Fcore = H.copy();
SimpleMatrix Fp0 = X.transpose().mult(Fcore).mult(X);
EigenResult e0 = solveEigen(Fp0);
SimpleMatrix C0 = X.mult(e0.vectors);
SimpleMatrix Cocc0 = C0.extractMatrix(0, n, 0, ndocc);
SimpleMatrix D = Cocc0.mult(Cocc0.transpose()).scale(2.0);
// Nuclear repulsion
double Enuc = 0.0;
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) {
double r2 = 0.0;
for (int k = 0; k < 3; k++) {
double d = nuclei[i].R[k] - nuclei[j].R[k];
r2 += d * d;
}
Enuc += (double) nuclei[i].Z * nuclei[j].Z / Math.sqrt(r2);
}
}
double[] eps = null;
boolean converged = false;
double Eold = 0.0;
boolean firstIter = true;
// Preallocate matrices for SCF loop (reduces GC pressure)
SimpleMatrix J = new SimpleMatrix(n, n);
SimpleMatrix K = new SimpleMatrix(n, n);
SimpleMatrix F = new SimpleMatrix(n, n);
SimpleMatrix Fp = new SimpleMatrix(n, n);
SimpleMatrix C = new SimpleMatrix(n, n);
SimpleMatrix Cocc = new SimpleMatrix(n, ndocc);
SimpleMatrix Dnew = new SimpleMatrix(n, n);
for (int iter = 0; iter < 128; iter++) {
// Build Fock matrix (in-place to reuse allocated matrices)
buildJInPlace(D, eri, J, n);
buildKInPlace(D, eri, K, n);
// F = H + J - 0.5*K (scale creates new matrix to avoid mutation)
SimpleMatrix Khalf = K.scale(0.5);
F = H.plus(J).minus(Khalf);
// Transform and solve
Fp = X.transpose().mult(F).mult(X);
EigenResult eigen = solveEigen(Fp);
eps = eigen.values;
C = X.mult(eigen.vectors);
// Build density
Cocc = C.extractMatrix(0, n, 0, ndocc);
Dnew = Cocc.mult(Cocc.transpose()).scale(2.0);
// Energy
double E = 0.5 * H.plus(F).elementMult(Dnew).elementSum() + Enuc;
// Check convergence (skip first iteration to avoid Infinity)
if (!firstIter) {
double dE = Math.abs(E - Eold);
double dD = Dnew.minus(D).normF();
D = Dnew;
Eold = E;
if (dD < 1e-5 && dE < 1e-8) {
converged = true;
break;
}
} else {
D = Dnew;
Eold = E;
firstIter = false;
}
}
if (!converged) {
System.err.println("Warning: SCF did not converge in 128 iterations");
}
// Pack outputs
double gap = (ndocc < eps.length) ? eps[ndocc] - eps[ndocc - 1] : Double.NaN;
int outSize = Math.min(bins, n * n);
double[] outputs = new double[outSize + 1];
// Flatten density matrix
int idx = 0;
for (int i = 0; i < n && idx < outSize; i++) {
for (int j = 0; j < n && idx < outSize; j++) {
outputs[idx++] = D.get(i, j);
}
}
outputs[outSize] = gap;
return outputs;
}
// Build orthogonalizer S^(-1/2)
private static SimpleMatrix buildOrthogonalizer(SimpleMatrix S) {
// Enforce symmetry to improve numerical stability
SimpleMatrix Ssym = S.plus(S.transpose()).scale(0.5);
EigenResult eigen = solveEigen(Ssym);
int n = Ssym.numRows();
SimpleMatrix U = eigen.vectors;
SimpleMatrix invSqrt = new SimpleMatrix(n, n);
// Find smallest eigenvalue for diagnostic purposes
double minEigenvalue = eigen.values[0];
for (int i = 1; i < n; i++) {
minEigenvalue = Math.min(minEigenvalue, eigen.values[i]);
}
for (int i = 0; i < n; i++) {
double lam = eigen.values[i];
// Fail if eigenvalue is significantly negative (not positive definite)
if (lam < -1e-6) {
throw new IllegalStateException(
"S not positive definite. eigen[" + i + "] = " + lam +
", smallest eigenvalue = " + minEigenvalue);
}
// Clamp tiny negative noise and prevent division by zero
lam = Math.max(lam, 1e-10);
invSqrt.set(i, i, 1.0 / Math.sqrt(lam));
}
return U.mult(invSqrt).mult(U.transpose());
}
// Build Coulomb matrix J (in-place for performance)
// J_mn = sum_ls D_ls (mn|ls)
private static void buildJInPlace(SimpleMatrix D, double[][][][] eri, SimpleMatrix J, int n) {
for (int m = 0; m < n; m++) {
for (int nIdx = 0; nIdx < n; nIdx++) {
double sum = 0.0;
for (int l = 0; l < n; l++) {
for (int s = 0; s < n; s++) {
sum += D.get(l, s) * eri[m][nIdx][l][s];
}
}
J.set(m, nIdx, sum);
}
}
}
// Build exchange matrix K (in-place for performance)
// K_mn = sum_ls D_ls (ml|ns)
private static void buildKInPlace(SimpleMatrix D, double[][][][] eri, SimpleMatrix K, int n) {
for (int m = 0; m < n; m++) {
for (int nIdx = 0; nIdx < n; nIdx++) {
double sum = 0.0;
for (int l = 0; l < n; l++) {
for (int s = 0; s < n; s++) {
sum += D.get(l, s) * eri[m][l][nIdx][s];
}
}
K.set(m, nIdx, sum);
}
}
}
// Build Coulomb matrix J (allocating version - kept for compatibility)
// J_mn = sum_ls D_ls (mn|ls)
private static SimpleMatrix buildJ(SimpleMatrix D, double[][][][] eri, int n) {
SimpleMatrix J = new SimpleMatrix(n, n);
buildJInPlace(D, eri, J, n);
return J;
}
// Build exchange matrix K (allocating version - kept for compatibility)
// K_mn = sum_ls D_ls (ml|ns)
private static SimpleMatrix buildK(SimpleMatrix D, double[][][][] eri, int n) {
SimpleMatrix K = new SimpleMatrix(n, n);
buildKInPlace(D, eri, K, n);
return K;
}
// Eigendecomposition result
private static class EigenResult {
double[] values;
SimpleMatrix vectors;
EigenResult(double[] values, SimpleMatrix vectors) {
this.values = values;
this.vectors = vectors;
}
}
// Solve eigenvalue problem using SimpleMatrix API
private static EigenResult solveEigen(SimpleMatrix M) {
// SimpleMatrix.eig() returns eigenvalues and eigenvectors
org.ejml.simple.SimpleEVD<?> evd = M.eig();
int n = M.getNumRows();
// Extract eigenvalues and eigenvectors
double[] values = new double[n];
SimpleMatrix vectors = new SimpleMatrix(n, n);
for (int i = 0; i < n; i++) {
values[i] = evd.getEigenvalue(i).getReal();
var vec = evd.getEigenVector(i);
if (vec == null) {
// Should not happen for symmetric matrices, but guard anyway
throw new IllegalStateException("Null eigenvector at index " + i);
}
for (int j = 0; j < n; j++) {
vectors.set(j, i, vec.get(j, 0));
}
}
// Sort eigenvalues and eigenvectors in ascending order
Integer[] indices = new Integer[n];
for (int i = 0; i < n; i++) indices[i] = i;
java.util.Arrays.sort(indices, (a, b) -> Double.compare(values[a], values[b]));
double[] sortedValues = new double[n];
SimpleMatrix sortedVectors = new SimpleMatrix(n, n);
for (int i = 0; i < n; i++) {
int idx = indices[i];
sortedValues[i] = values[idx];
for (int j = 0; j < n; j++) {
sortedVectors.set(j, i, vectors.get(j, idx));
}
}
return new EigenResult(sortedValues, sortedVectors);
}
/**
* Convenience method: solve RHF and derive HMAC key in one call.
* Ideal for Android apps that just need the final hex key.
*
* @param atoms Array of [x, y, z] coordinates (all H atoms, in bohr)
* @param nelec Number of electrons (must be even for RHF)
* @param bins Number of density matrix elements to output
* @param seed Integer seed for HMAC derivation
* @return Hex-encoded HMAC-SHA256 key
*/
public static String solveAndDeriveKey(double[][] atoms, int nelec, int bins, int seed) throws Exception {
double[] outputs = solve(atoms, nelec, bins);
return deriveKey(seed, outputs);
}
/**
* Derive HMAC-SHA256 key from seed and solver outputs.
* Pack doubles and seed as little-endian and derive HMAC-SHA256.
*
* @param seed Integer seed (used as HMAC message, 4-byte little-endian)
* @param outputs Solver output array (density matrix elements + gap)
* @return Hex-encoded HMAC-SHA256 digest
*
* Construction:
* 1. SHA-256 hash the packed outputs → 32-byte key
* 2. HMAC-SHA256(key=hash, message=seed)
* This provides a fixed-size key regardless of output array length.
*/
public static String deriveKey(int seed, double[] outputs) throws Exception {
// Pack outputs into bytes (little-endian doubles)
ByteBuffer buf = ByteBuffer.allocate(outputs.length * 8);
buf.order(ByteOrder.LITTLE_ENDIAN);
for (double v : outputs) {
buf.putDouble(v);
}
// SHA-256 hash of outputs → 32-byte HMAC key
MessageDigest sha = MessageDigest.getInstance("SHA-256");
sha.update(buf.array());
byte[] mixed = sha.digest(); // 32 bytes
// Pack seed into 4 bytes (little-endian)
byte[] seedBytes = new byte[4];
ByteBuffer.wrap(seedBytes)
.order(ByteOrder.LITTLE_ENDIAN)
.putInt(seed);
// HMAC-SHA256(key=hash(outputs), message=seed)
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(mixed, "HmacSHA256"));
byte[] hash = mac.doFinal(seedBytes);
// Convert to hex
StringBuilder hex = new StringBuilder(hash.length * 2);
for (byte b : hash) {
hex.append(String.format("%02x", b));
}
return hex.toString();
}
/**
* Example usage
*/
public static void main(String[] args) throws Exception {
// H2 molecule example
double[][] atoms = {
{0.0, 0.0, 0.0},
{0.0, 0.0, 1.4}
};
int nelec = 2;
int bins = 4;
int seed = 12345;
System.out.println("Solving H2 molecule...");
double[] outputs = solve(atoms, nelec, bins);
System.out.println("Density matrix elements:");
for (int i = 0; i < outputs.length - 1; i++) {
System.out.printf(" D[%d] = %.15e%n", i, outputs[i]);
}
System.out.printf(" Gap = %.15e%n", outputs[outputs.length - 1]);
String key = deriveKey(seed, outputs);
System.out.println("\nHMAC-SHA256 key:");
System.out.println(key);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment