Skip to content

Instantly share code, notes, and snippets.

@MurageKibicho
Created October 17, 2025 09:06
Show Gist options
  • Save MurageKibicho/955133f061926b3666d314359a868081 to your computer and use it in GitHub Desktop.
Save MurageKibicho/955133f061926b3666d314359a868081 to your computer and use it in GitHub Desktop.
Sinkhorn-Knopp Algorithm with total support checking
//Complete LeetArxiv walkthrough: https://leetarxiv.substack.com/p/sinkhorn-knopp-algorithm
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <math.h>
#include <string.h>
#include <assert.h>
#define INDEX(x, y, cols) ((x) * (cols) + (y))
//clear && gcc SinkhornKnopp.c -lm -o m.o && ./m.o
/*
 * Print a one-line, human-readable description of a code returned by
 * SinkhornProperties(). Codes -1..-5 describe a rejected matrix, 1 and 2 an
 * accepted one; any other value is reported as unknown.
 */
void PrintSinkhornCode(int sinkhornCode)
{
switch(sinkhornCode)
{
case -1:
printf("(%d), Matrix is not a square matrix\n",sinkhornCode);
break;
case -2:
printf("(%d), Matrix row all zero \n",sinkhornCode);
break;
case -3:
printf("(%d), Matrix col all zero \n",sinkhornCode);
break;
case -4:
//Code -4 is produced when some element is < 0, i.e. the matrix has a NEGATIVE element
printf("(%d), Matrix has negative element\n",sinkhornCode);
break;
case -5:
printf("(%d), Determinant is zero\n",sinkhornCode);
break;
case 1:
printf("(%d), All matrix elements are greater than zero\n",sinkhornCode);
break;
case 2:
printf("(%d), Determinant is non-zero\n",sinkhornCode);
break;
default:
printf("Unknown sinkhorn code\n");
break;
}
}
/*
 * Dump a row-major rows x cols matrix to stdout: every entry is printed with
 * three decimals and a trailing comma, one matrix row per line, followed by
 * an extra blank line.
 */
void PrintMatrix(int rows, int cols, double *matrix)
{
const double *cursor = matrix;
for(int r = 0; r < rows; r++)
{
for(int c = 0; c < cols; c++)
{
printf("%.3f,", *cursor++);
}
putchar('\n');
}
putchar('\n');
}
/*
 * LU factorisation with partial pivoting, P*A = L*U, plus the determinant.
 *   n           order of the square, row-major input matrix
 *   inputMatrix n*n input; not modified (a private copy is eliminated)
 *   L           out: n*n unit lower-triangular factor
 *   U           out: n*n upper-triangular factor
 *   P           out: n-entry row permutation; row i of P*A is row P[i] of A
 * Returns det(inputMatrix).
 *   0.0  if every candidate pivot at some step has magnitude < 1e-12
 *        (singular or near-singular); a warning is printed.
 *   NAN  if a non-finite U/L entry appears or memory cannot be allocated.
 *   A non-finite determinant (overflow) is returned as-is after a message.
 * On any early return L, U and P are only partially filled.
 * NOTE(review): the 1e-12 pivot threshold is absolute, hence scale dependent.
 */
double LUDecomposition(int n, double *inputMatrix, double *L, double *U, int *P)
{
if(n <= 0){return 1.0;} //determinant of an empty matrix (empty product)
size_t count = (size_t)n * (size_t)n;
double *A = malloc(count * sizeof *A);
if(A == NULL)
{
printf("ERROR: Out of memory in LUDecomposition\n");
return NAN;
}
memcpy(A, inputMatrix, count * sizeof *A);
//Initialize L to identity, U to zero and P to the identity permutation
for(int i = 0; i < n; i++)
{
P[i] = i;
for(int j = 0; j < n; j++)
{
L[i * n + j] = (i == j) ? 1.0 : 0.0;
U[i * n + j] = 0.0;
}
}
double det = 1.0;
for(int i = 0; i < n; i++)
{
//Partial pivoting: pick the largest |entry| of column i among rows i..n-1 of the
//PARTIALLY ELIMINATED matrix. Using the original entries here would wrongly
//reject matrices such as [[1,2],[1,0]] as singular.
double max = fabs(A[i * n + i]);
int pivot = i;
for(int k = i + 1; k < n; k++)
{
double val = fabs(A[k * n + i]);
if(val > max){max = val; pivot = k;}
}
if(max < 1e-12)
{
printf("WARNING: Singular or near-singular matrix at step %d\n", i);
printf("Max pivot value: %e\n", max);
free(A);
return 0.0;
}
if(pivot != i)
{
for(int j = 0; j < n; j++){double temp = A[i * n + j]; A[i * n + j] = A[pivot * n + j]; A[pivot * n + j] = temp;}
//Only the already-computed columns 0..i-1 of L travel with the rows
for(int j = 0; j < i; j++){double temp = L[i * n + j]; L[i * n + j] = L[pivot * n + j]; L[pivot * n + j] = temp;}
int temp = P[i];
P[i] = P[pivot];
P[pivot] = temp;
det = -det; //each row swap flips the sign of the determinant
}
//Row i is now fully reduced: it is row i of U
for(int j = i; j < n; j++)
{
U[i * n + j] = A[i * n + j];
if(!isfinite(U[i * n + j]))
{
printf("ERROR: Non-finite U[%d][%d] = %f\n", i, j, U[i * n + j]);
free(A);
return NAN;
}
}
double divisor = U[i * n + i];
//Eliminate column i below the pivot, storing the multipliers in L
for(int k = i + 1; k < n; k++)
{
double factor = A[k * n + i] / divisor;
if(!isfinite(factor))
{
printf("\nERROR: Non-finite L[%d][%d] = %f\n", k, i, factor);
free(A);
return NAN;
}
L[k * n + i] = factor;
A[k * n + i] = 0.0;
for(int j = i + 1; j < n; j++){A[k * n + j] -= factor * A[i * n + j];}
}
det *= divisor;
if(!isfinite(det))
{
printf("\nERROR: Non-finite determinant at step %d: %f\n", i, det);
printf("U[%d][%d] = %f\n", i, i, divisor);
free(A);
return det;
}
}
free(A);
return det;
}
/*
 * Dense matrix product C = A * B, all row-major.
 *   A is rowA x colA, B is colA x colB, C (rowA x colB) is overwritten.
 * Each C entry is accumulated in increasing k order (plain triple loop).
 */
void NaiveMatmul(int rowA, int colA, int colB, double *A, double *B, double *C)
{
for(int r = 0; r < rowA; r++)
{
const double *aRow = A + r * colA;
double *cRow = C + r * colB;
for(int c = 0; c < colB; c++)
{
double acc = 0.0;
for(int k = 0; k < colA; k++)
{
acc += aRow[k] * B[k * colB + c];
}
cRow[c] = acc;
}
}
}
/*
 * Apply a row permutation: row i of PA (n x n, row-major) becomes a copy of
 * row P[i] of A. P must hold n valid row indices.
 */
void FindPAMatrix(int n, int *P, double *PA, double *A)
{
for(int row = 0; row < n; row++)
{
const double *src = A + P[row] * n;
double *dst = PA + row * n;
for(int col = 0; col < n; col++)
{
dst[col] = src[col];
}
}
}
/*
 * Classify a matrix before running Sinkhorn-Knopp on it.
 * Checks, in order: squareness, strict positivity of every entry, an all-zero
 * row, an all-zero column, a negative entry, and finally invertibility via LU.
 * Returns:
 *   -1 not square               -2 some row is entirely zero
 *   -3 some column is entirely zero
 *   -4 some entry is negative   -5 LU reports a zero determinant
 *    1 every entry is > 0        2 nonnegative with a nonzero determinant
 *    0 an internal allocation failed
 * NOTE(review): a nonzero determinant proves the matrix has *support* (some
 * permutation diagonal is all positive) but not *total support*; e.g.
 * [[1,1],[0,1]] returns 2 although its (0,1) entry lies on no positive diagonal.
 */
int SinkhornProperties(int rows, int cols, double *matrix)
{
//Test 1: Check square matrix
if(rows != cols){return -1;}
int n = rows;
//Tests 2 and 3: one pass records positivity, negativity and which rows/cols have a nonzero
bool allGreaterThanZero = true;
bool anyNegative = false;
bool *rowHasNonZero = calloc(n, sizeof *rowHasNonZero);
bool *colHasNonZero = calloc(n, sizeof *colHasNonZero);
if(n > 0 && (rowHasNonZero == NULL || colHasNonZero == NULL))
{
free(rowHasNonZero);
free(colHasNonZero);
return 0;
}
for(int i = 0; i < n; i++)
{
for(int j = 0; j < n; j++)
{
double element = matrix[INDEX(i, j, n)];
if(element <= 0.0){allGreaterThanZero = false;} //negatives are not > 0 either
if(element < 0.0){anyNegative = true;}
//Track nonzeros directly: with negative entries a zero row SUM does not mean a zero row
if(element != 0.0){rowHasNonZero[i] = true; colHasNonZero[j] = true;}
}
}
bool zeroRow = false;
bool zeroCol = false;
for(int i = 0; i < n; i++){if(!rowHasNonZero[i]){zeroRow = true;break;}}
for(int j = 0; j < n; j++){if(!colHasNonZero[j]){zeroCol = true;break;}}
free(rowHasNonZero);
free(colHasNonZero);
if(allGreaterThanZero){return 1;}
if(zeroRow){return -2;}
if(zeroCol){return -3;}
if(anyNegative){return -4;}
//Test 4: Check invertibility by LU decomposition (n >= 1 here, since n == 0 returned 1)
double *lower = calloc((size_t)n * (size_t)n, sizeof *lower);
double *upper = calloc((size_t)n * (size_t)n, sizeof *upper);
int *permutation = calloc(n, sizeof *permutation);
if(lower == NULL || upper == NULL || permutation == NULL)
{
free(lower);free(upper);free(permutation);
return 0;
}
double determinant = LUDecomposition(n, matrix, lower, upper, permutation);
free(lower);free(upper);free(permutation);
return (determinant == 0.0) ? -5 : 2;
}
/*
 * Exercise every return code of SinkhornProperties() on small hand-written
 * matrices, so the expected outcomes do not depend on the C library's rand()
 * sequence. Each case asserts the expected code and prints its description.
 */
void TestSinkhornProperties()
{
//Test 1: non-square (3 x 5) input
double a[3 * 5] = {0};
int codeA = SinkhornProperties(3, 5, a);
assert(codeA == -1);
printf("Test 1 passed: ");PrintSinkhornCode(codeA);
//Test 2: contains a negative entry; no row or column is entirely zero
double b[3 * 3] = {1, -2, 3,
                   0,  5, 6,
                   7,  8, 9};
int codeB = SinkhornProperties(3, 3, b);
assert(codeB == -4);
printf("Test 2 passed: ");PrintSinkhornCode(codeB);
//Test 3: every entry strictly positive
double c[3 * 3] = {1, 2, 3,
                   4, 5, 6,
                   7, 8, 9};
int codeC = SinkhornProperties(3, 3, c);
assert(codeC == 1);
printf("Test 3 passed: ");PrintSinkhornCode(codeC);
//Test 4: row 1 is entirely zero
double d[3 * 3] = {1, 2, 3,
                   0, 0, 0,
                   4, 5, 6};
int codeD = SinkhornProperties(3, 3, d);
assert(codeD == -2);
printf("Test 4 passed: ");PrintSinkhornCode(codeD);
//Test 5: column 2 is entirely zero
double e[3 * 3] = {1, 2, 0,
                   3, 4, 0,
                   5, 6, 0};
int codeE = SinkhornProperties(3, 3, e);
assert(codeE == -3);
printf("Test 5 passed: ");PrintSinkhornCode(codeE);
//Test 6: nonnegative with some zeros, determinant 25
double f[3 * 3] = {2, 0, 1,
                   1, 3, 0,
                   0, 1, 4};
//Test 7: row 1 is twice row 0, so the determinant is zero
double g[3 * 3] = {1, 0, 2,
                   2, 0, 4,
                   3, 1, 5};
int codeF = SinkhornProperties(3, 3, f);
int codeG = SinkhornProperties(3, 3, g);
assert(codeF == 2);
printf("Test 6 passed: ");PrintSinkhornCode(codeF);
assert(codeG == -5);
printf("Test 7 passed: ");PrintSinkhornCode(codeG);
}
/*
 * Balancing residual of diag(rowScaling) * matrix * diag(colScaling) for an
 * n x n row-major matrix: the largest |row sum - 1| or |column sum - 1|.
 */
static double SinkhornResidual(int n, const double *matrix, const double *rowScaling, const double *colScaling)
{
double worst = 0.0;
for(int i = 0; i < n; i++)
{
double rowSum = 0.0;
double colSum = 0.0;
for(int j = 0; j < n; j++)
{
rowSum += rowScaling[i] * matrix[i * n + j] * colScaling[j];
colSum += rowScaling[j] * matrix[j * n + i] * colScaling[i];
}
double rowDev = fabs(rowSum - 1.0);
double colDev = fabs(colSum - 1.0);
if(rowDev > worst){worst = rowDev;}
if(colDev > worst){worst = colDev;}
}
return worst;
}
/*
 * Sinkhorn-Knopp balancing: find scale vectors r, c such that
 * diag(r) * matrix * diag(c) has every row and column summing to 1.
 *   rows, cols          matrix dimensions; must be equal (asserted)
 *   sinkhornIterations  maximum number of row+column sweeps
 *   sinkhornTolerance   stop as soon as the residual drops below this value
 *   matrix              n*n row-major input; not modified
 *   matrixCopy          out: n*n balanced matrix diag(r)*matrix*diag(c)
 *   rowScaling          out: r (n entries)
 *   colScaling          out: c (n entries)
 * Returns the residual of matrixCopy: the largest |row sum - 1| or
 * |column sum - 1|. A row/column whose weighted sum is <= 1e-12 keeps its
 * previous scale factor. Convergence is only expected when the matrix has
 * total support; the caller is assumed to have checked that.
 */
double SinkhornKnoppAlgorithm(int rows, int cols, int sinkhornIterations, double sinkhornTolerance, double *matrix, double *matrixCopy, double *rowScaling, double *colScaling)
{
//We assume you already checked for total support
assert(rows == cols);
int n = rows;
//Initialize Scaling Vectors to 1
for(int i = 0; i < n; i++)
{
rowScaling[i] = 1.0;
colScaling[i] = 1.0;
}
//Residual of the unscaled matrix, so zero iterations still return a meaningful value
double maxError = SinkhornResidual(n, matrix, rowScaling, colScaling);
for(int iteration = 0; iteration < sinkhornIterations; iteration++)
{
//Update Row Scaling: make each row of diag(r)*A*diag(c) sum to 1
for(int i = 0; i < n; i++)
{
double rowSum = 0.0;
for(int j = 0; j < n; j++){rowSum += matrix[i * n + j] * colScaling[j];}
if(rowSum > 1e-12){rowScaling[i] = 1.0 / rowSum;}
}
//Update column scaling: make each column sum to 1
for(int j = 0; j < n; j++)
{
double colSum = 0.0;
for(int i = 0; i < n; i++){colSum += rowScaling[i] * matrix[i * n + j];}
if(colSum > 1e-12){colScaling[j] = 1.0 / colSum;}
}
//Measure the SCALED matrix. The raw sums above tend to 1/r_i and 1/c_j, not to 1,
//so they cannot serve as a convergence test.
maxError = SinkhornResidual(n, matrix, rowScaling, colScaling);
if(maxError < sinkhornTolerance){break;}
}
//Apply scaling
for(int i = 0; i < n; i++)
{
for(int j = 0; j < n; j++)
{
matrixCopy[i * n + j] = rowScaling[i] * matrix[i * n + j] * colScaling[j];
}
}
return maxError;
}
/*
 * Demo: fill a 5x5 matrix with rand() % 10, require SinkhornProperties() to
 * accept it (code > 0), balance it with SinkhornKnoppAlgorithm(), then print
 * the reported error, the input and the balanced matrix.
 */
void TestSinkhornBalance()
{
int rowE = 5; int colE = 5;
double *e = calloc(rowE * colE, sizeof(double));
double *eCopy = calloc(rowE * colE, sizeof(double));
double *rowScaling = calloc(rowE, sizeof(double));
double *colScaling = calloc(colE, sizeof(double));
if(e == NULL || eCopy == NULL || rowScaling == NULL || colScaling == NULL)
{
printf("ERROR: Out of memory in TestSinkhornBalance\n");
free(e);free(eCopy);free(rowScaling);free(colScaling);
return;
}
for(int i = 0; i < rowE; i++)
{
for(int j = 0; j < colE; j++)
{
e[INDEX(i,j, colE)] = rand() % 10;
}
}
int codeE = SinkhornProperties(rowE, colE, e);
assert(codeE > 0);
int sinkhornIterations = 4000;
double sinkhornTolerance = 1e-10;
double sinkhornError = SinkhornKnoppAlgorithm(rowE, colE, sinkhornIterations, sinkhornTolerance, e, eCopy, rowScaling, colScaling);
//Scientific notation: the error is expected near the 1e-10 tolerance, which %.3f shows as 0.000
printf("\nSinkhorn Error: %.3e\n", sinkhornError);
PrintMatrix(rowE, colE, e);
PrintMatrix(rowE, colE, eCopy);
free(e);free(eCopy);free(rowScaling);free(colScaling);
}
/* Entry point: run the property-classification tests, then the balancing demo. */
int main(void)
{
TestSinkhornProperties();
TestSinkhornBalance();
return EXIT_SUCCESS;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment