Created
October 17, 2025 09:06
-
-
Save MurageKibicho/955133f061926b3666d314359a868081 to your computer and use it in GitHub Desktop.
Sinkhorn-Knopp Algorithm with total support checking
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| //Complete LeetArxiv walkthrough: https://leetarxiv.substack.com/p/sinkhorn-knopp-algorithm | |
| #include <stdio.h> | |
| #include <stdlib.h> | |
| #include <stdbool.h> | |
| #include <math.h> | |
| #include <string.h> | |
| #include <assert.h> | |
| #define INDEX(x, y, cols) ((x) * (cols) + (y)) | |
| //clear && gcc SinkhornKnopp.c -lm -o m.o && ./m.o | |
//Print a one-line, human-readable description of a SinkhornProperties status code.
//  sinkhornCode: negative codes (-1..-5) describe why a matrix failed a check,
//                positive codes (1, 2) describe which check it passed.
//Any other value prints "Unknown sinkhorn code".
void PrintSinkhornCode(int sinkhornCode)
{
    const char *description = NULL;
    switch(sinkhornCode)
    {
        case -1: description = "Matrix is not a square matrix"; break;
        case -2: description = "Matrix row all zero "; break;
        case -3: description = "Matrix col all zero "; break;
        //Code -4 is reported when an element is below zero, so the text must say "negative"
        case -4: description = "Matrix has negative element"; break;
        case -5: description = "Determinant is zero"; break;
        case 1:  description = "All matrix elements are greater than zero"; break;
        case 2:  description = "Determinant is non-zero"; break;
        default: break;
    }
    if(description == NULL)
    {
        printf("Unknown sinkhorn code\n");
        return;
    }
    printf("(%d), %s\n", sinkhornCode, description);
}
//Write a rows x cols row-major matrix to stdout, one matrix row per line.
//Each element is printed with three decimals followed by a comma, and a blank
//line is emitted after the last row.
void PrintMatrix(int rows, int cols, double *matrix)
{
    //Row-major storage means the elements are visited in plain memory order
    const double *cursor = matrix;
    int r = 0;
    while(r < rows)
    {
        int c = 0;
        while(c < cols)
        {
            printf("%.3f,", *cursor);
            cursor++;
            c++;
        }
        printf("\n");
        r++;
    }
    printf("\n");
}
//LU decomposition with partial (row) pivoting of an n x n row-major matrix.
//  n           : matrix order
//  inputMatrix : n*n input elements, left unmodified
//  L           : n*n output, unit lower-triangular factor
//  U           : n*n output, upper-triangular factor
//  P           : n output entries; row i of the permuted matrix is row P[i] of inputMatrix
//On success the factors satisfy (permuted inputMatrix) = L * U and the return
//value is the determinant of inputMatrix.
//Returns 0.0 (after printing a warning) when no pivot of magnitude >= 1e-12 exists,
//NAN when a non-finite factor entry appears or the working copy cannot be
//allocated, and the non-finite determinant itself if the running product overflows.
//On any early return L and U are only partially filled.
double LUDecomposition(int n, double *inputMatrix, double *L, double *U, int *P)
{
    const double singularThreshold = 1e-12;
    double det = 1.0;
    //Nothing to factor: the determinant of an empty matrix is the empty product
    if(n <= 0){return det;}
    //Initialize L to identity and U to zero
    for(int i = 0; i < n; i++){for(int j = 0; j < n; j++){L[i * n + j] = (i == j) ? 1.0 : 0.0;U[i * n + j] = 0.0;}}
    //Initialize P to identity
    for(int i = 0; i < n; i++){P[i] = i;}
    //Working copy that is reduced in place; after step i its trailing block
    //holds the Schur complement, which is what pivoting must look at
    size_t elementCount = (size_t)n * (size_t)n;
    double *A = malloc(elementCount * sizeof *A);
    if(A == NULL)
    {
        printf("ERROR: Could not allocate working copy for %d x %d matrix\n", n, n);
        return NAN;
    }
    memcpy(A, inputMatrix, elementCount * sizeof *A);
    for(int i = 0; i < n; i++)
    {
        //Partial pivoting on the reduced column i (rows i..n-1)
        double max = fabs(A[i * n + i]);
        int pivot = i;
        for(int k = i + 1; k < n; k++){double val = fabs(A[k * n + i]);if(val > max){max = val;pivot = k;}}
        if(max < singularThreshold)
        {
            printf("WARNING: Singular or near-singular matrix at step %d\n", i);
            printf("Max pivot value: %e\n", max);
            free(A);
            return 0.0;
        }
        if(pivot != i)
        {
            //Swap rows in A
            for(int j = 0; j < n; j++){double temp = A[i * n + j];A[i * n + j] = A[pivot * n + j];A[pivot * n + j] = temp;}
            //Swap the already computed multipliers in L
            for(int j = 0; j < i; j++){double temp = L[i * n + j];L[i * n + j] = L[pivot * n + j];L[pivot * n + j] = temp;}
            int temp = P[i];
            P[i] = P[pivot];
            P[pivot] = temp;
            //Each row exchange flips the sign of the determinant
            det = -det;
        }
        //Row i of U is row i of the reduced working matrix
        for(int j = i; j < n; j++)
        {
            U[i * n + j] = A[i * n + j];
            //CHECK FOR NaN/Inf
            if(!isfinite(U[i * n + j]))
            {
                printf("ERROR: Non-finite U[%d][%d] = %f\n", i, j, U[i * n + j]);
                free(A);
                return NAN;
            }
        }
        //The pivot is finite and its magnitude is at least singularThreshold here
        double divisor = U[i * n + i];
        for(int j = i + 1; j < n; j++)
        {
            double factor = A[j * n + i] / divisor;
            L[j * n + i] = factor;
            //CHECK FOR NaN/Inf
            if(!isfinite(factor))
            {
                printf("\nERROR: Non-finite L[%d][%d] = %f\n", j, i, factor);
                printf("A[%d][%d] = %f, divisor = %f\n", j, i, A[j * n + i], divisor);
                free(A);
                return NAN;
            }
            //Eliminate column i from row j so later pivot searches see reduced values
            for(int k = i + 1; k < n; k++){A[j * n + k] -= factor * U[i * n + k];}
        }
        det *= divisor;
        //CHECK DETERMINANT
        if(!isfinite(det))
        {
            printf("\nERROR: Non-finite determinant at step %d: %f\n", i, det);
            printf("U[%d][%d] = %f\n", i, i, U[i * n + i]);
            free(A);
            return det;
        }
    }
    free(A);
    return det;
}
//Triple-loop product of two row-major matrices: C = A * B.
//  A is rowA x colA, B is colA x colB, and C (rowA x colB) receives the result.
//Every entry of C is overwritten.
void NaiveMatmul(int rowA, int colA, int colB,double *A, double *B, double *C)
{
    for(int r = 0; r < rowA; r++)
    {
        for(int c = 0; c < colB; c++)
        {
            //Dot product of row r of A with column c of B, accumulated left to right
            double dot = 0.0;
            int t = 0;
            while(t < colA)
            {
                double left = A[r * colA + t];
                double right = B[t * colB + c];
                dot += left * right;
                t++;
            }
            C[r * colB + c] = dot;
        }
    }
}
//Build the row-permuted matrix PA from A.
//  n  : order of the square, row-major matrices
//  P  : n row indices; row i of PA is a copy of row P[i] of A
//  PA : n*n output
//  A  : n*n input, not modified
void FindPAMatrix(int n, int *P, double *PA, double *A)
{
    for(int dstRow = 0; dstRow < n; dstRow++)
    {
        //Offsets of the first element of the destination and source rows
        int dstBase = dstRow * n;
        int srcBase = P[dstRow] * n;
        int col = 0;
        while(col < n)
        {
            PA[dstBase + col] = A[srcBase + col];
            col++;
        }
    }
}
//Classify a row-major rows x cols matrix before running Sinkhorn-Knopp on it.
//Checks are applied in this order and the first match is returned:
//  -1 : matrix is not square
//   1 : every element is strictly greater than zero
//  -2 : some row consists only of zeros
//  -3 : some column consists only of zeros
//  -4 : some element is negative
//  -5 : LU decomposition reports a zero (or NaN) determinant
//   2 : LU decomposition reports a non-zero determinant
//   0 : the check could not be carried out because a scratch allocation failed
int SinkhornProperties(int rows, int cols, double *matrix)
{
    //Test 1: Check square matrix
    if(rows != cols){return -1;}
    int n = rows;
    //No element can violate positivity in an empty matrix
    if(n <= 0){return 1;}
    bool allGreaterThanZero = true;
    bool hasNegative = false;
    bool zeroRow = false;
    bool zeroCol = false;
    //Test 2: Check if all entries are greater than 0
    //Test 3: Also check if entire row or col is fully 0
    //Magnitudes are summed so that entries of opposite sign cannot cancel
    //and make a non-zero row or column look empty
    double *rowMagnitude = calloc((size_t)n, sizeof *rowMagnitude);
    double *colMagnitude = calloc((size_t)n, sizeof *colMagnitude);
    if(rowMagnitude == NULL || colMagnitude == NULL)
    {
        free(rowMagnitude);
        free(colMagnitude);
        return 0;
    }
    for(int i = 0; i < n; i++)
    {
        for(int j = 0; j < n; j++)
        {
            double element = matrix[i * n + j];
            //Zero, negative and NaN entries all fail the strict positivity test
            if(!(element > 0.0)){allGreaterThanZero = false;}
            if(element < 0.0){hasNegative = true;}
            rowMagnitude[i] += fabs(element);
            colMagnitude[j] += fabs(element);
        }
    }
    for(int i = 0; i < n; i++){if(rowMagnitude[i] == 0.0){zeroRow = true;break;}}
    for(int j = 0; j < n; j++){if(colMagnitude[j] == 0.0){zeroCol = true;break;}}
    free(rowMagnitude);
    free(colMagnitude);
    if(allGreaterThanZero){return 1;}
    if(zeroRow){return -2;}
    if(zeroCol){return -3;}
    if(hasNegative){return -4;}
    //Test 4: Check invertibility by LU Decomposition
    //Only the determinant is needed; the factors are scratch output
    size_t elementCount = (size_t)n * (size_t)n;
    double *lower = calloc(elementCount, sizeof *lower);
    double *upper = calloc(elementCount, sizeof *upper);
    int *permutation = calloc((size_t)n, sizeof *permutation);
    if(lower == NULL || upper == NULL || permutation == NULL)
    {
        free(lower);free(upper);free(permutation);
        return 0;
    }
    double determinant = LUDecomposition(n, matrix, lower, upper, permutation);
    free(lower);free(upper);free(permutation);
    //A NaN determinant means the decomposition failed, so invertibility is not established
    if(isnan(determinant) || determinant == 0.0)
    {
        return -5;
    }
    return 2;
}
//Exercise every status code of SinkhornProperties with fixed matrices.
//Each matrix is built deterministically so the asserted code does not depend
//on the values an unseeded rand() happens to produce.
void TestSinkhornProperties()
{
    enum { N = 5 };
    //Test 1: a 3 x 5 matrix is not square
    double a[3 * N] = {0};
    int codeA = SinkhornProperties(3, N, a);
    assert(codeA == -1);
    printf("Test 1 passed: ");PrintSinkhornCode(codeA);
    //Test 2: mixed-sign entries plus one zero entry; no row or column sums to zero
    double b[N * N];
    for(int i = 0; i < N; i++)
    {
        for(int j = 0; j < N; j++)
        {
            int k = i * N + j;
            b[k] = i + j + 1;
            if(k % 2 == 0){b[k] *= -1;}
        }
    }
    b[1] = 0;
    int codeB = SinkhornProperties(N, N, b);
    assert(codeB == -4);
    printf("Test 2 passed: ");PrintSinkhornCode(codeB);
    //Test 3: every entry lies in 1..9, so all entries are strictly positive
    double c[N * N];
    for(int k = 0; k < N * N; k++){c[k] = k % 9 + 1;}
    int codeC = SinkhornProperties(N, N, c);
    assert(codeC == 1);
    printf("Test 3 passed: ");PrintSinkhornCode(codeC);
    //Test 4: row 2 is entirely zero
    double d[N * N];
    for(int i = 0; i < N; i++)
    {
        for(int j = 0; j < N; j++)
        {
            d[i * N + j] = (i == 2) ? 0 : i + j + 1;
        }
    }
    int codeD = SinkhornProperties(N, N, d);
    assert(codeD == -2);
    printf("Test 4 passed: ");PrintSinkhornCode(codeD);
    //Test 5: column 4 is entirely zero while every row keeps non-zero entries
    double e[N * N];
    for(int i = 0; i < N; i++)
    {
        for(int j = 0; j < N; j++)
        {
            e[i * N + j] = (j == 4) ? 0 : i + j + 1;
        }
    }
    int codeE = SinkhornProperties(N, N, e);
    assert(codeE == -3);
    printf("Test 5 passed: ");PrintSinkhornCode(codeE);
    //Tests 6 and 7: f is a diagonally dominant tridiagonal matrix (non-singular, contains zeros);
    //g is the same matrix except that row 1 is twice row 0, which makes its determinant zero
    double f[N * N];
    double g[N * N];
    for(int i = 0; i < N; i++)
    {
        for(int j = 0; j < N; j++)
        {
            double value = 0;
            if(i == j){value = 4;}
            else if(i - j == 1 || j - i == 1){value = 1;}
            f[i * N + j] = value;
            g[i * N + j] = value;
        }
    }
    for(int j = 0; j < N; j++){g[1 * N + j] = 2 * g[0 * N + j];}
    int codeF = SinkhornProperties(N, N, f);
    int codeG = SinkhornProperties(N, N, g);
    assert(codeF == 2);
    printf("Test 6 passed: ");PrintSinkhornCode(codeF);
    assert(codeG == -5);
    printf("Test 7 passed: ");PrintSinkhornCode(codeG);
}
//Sinkhorn-Knopp balancing: find positive row and column scalings so that
//diag(rowScaling) * matrix * diag(colScaling) has all row and column sums equal to 1.
//  rows, cols         : matrix dimensions; must be equal (asserted)
//  sinkhornIterations : maximum number of row/column update sweeps
//  sinkhornTolerance  : stop early once the measured deviation drops below this value
//  matrix             : rows*cols row-major input, not modified
//  matrixCopy         : rows*cols output, receives the scaled matrix
//  rowScaling         : rows output entries
//  colScaling         : cols output entries
//Returns the largest |sum - 1| over the rows and columns of the scaled matrix as
//measured during the last executed sweep (rows are measured just before their
//scaling is refreshed, columns likewise), or 0.0 when no sweep is executed.
//The caller is expected to have checked the matrix beforehand; sums that are
//not above 1e-12 leave the corresponding scaling factor unchanged.
double SinkhornKnoppAlgorithm(int rows, int cols, int sinkhornIterations, double sinkhornTolerance, double *matrix, double *matrixCopy, double *rowScaling, double *colScaling)
{
    //We assume you already checked for total support
    assert(rows == cols);
    const double minimumSum = 1e-12;
    int n = rows;
    //Initialize Scaling Vectors to 1
    for(int i = 0; i < n; i++)
    {
        rowScaling[i] = 1.0;
        colScaling[i] = 1.0;
    }
    double maxError = 0.0;
    for(int iteration = 0; iteration < sinkhornIterations; iteration++)
    {
        maxError = 0.0;
        //Update Row Scaling
        for(int i = 0; i < n; i++)
        {
            double rowSum = 0.0;
            for(int j = 0; j < n; j++)
            {
                rowSum += matrix[i * n + j] * colScaling[j];
            }
            //Row sum of the currently scaled matrix uses the scaling from the previous sweep
            double deviation = fabs(rowScaling[i] * rowSum - 1.0);
            if(deviation > maxError){maxError = deviation;}
            if(rowSum > minimumSum)
            {
                rowScaling[i] = 1.0 / rowSum;
            }
        }
        //Update column scaling
        for(int j = 0; j < n; j++)
        {
            double colSum = 0.0;
            for(int i = 0; i < n; i++)
            {
                colSum += rowScaling[i] * matrix[i * n + j];
            }
            //Column sum of the currently scaled matrix, before this column is rescaled
            double deviation = fabs(colScaling[j] * colSum - 1.0);
            if(deviation > maxError){maxError = deviation;}
            if(colSum > minimumSum)
            {
                colScaling[j] = 1.0 / colSum;
            }
        }
        if(maxError < sinkhornTolerance){break;}
    }
    //Apply scaling; this overwrites every entry of matrixCopy
    for(int i = 0; i < n; i++)
    {
        for(int j = 0; j < n; j++)
        {
            matrixCopy[i * n + j] = rowScaling[i] * matrix[i * n + j] * colScaling[j];
        }
    }
    return maxError;
}
//Demo: fill a 5 x 5 matrix with rand() % 10 values, require a positive
//SinkhornProperties code, balance it with SinkhornKnoppAlgorithm and print
//the reported error followed by the input and the balanced matrices.
void TestSinkhornBalance()
{
    int rowE = 5; int colE = 5;
    int elementCount = rowE * colE;
    double *e = calloc(elementCount, sizeof(double));
    double *eCopy = calloc(elementCount, sizeof(double));
    double *rowScaling = calloc(rowE, sizeof(double));
    double *colScaling = calloc(rowE, sizeof(double));
    //Row-major fill, one rand() draw per element
    for(int k = 0; k < elementCount; k++)
    {
        e[k] = rand() % 10;
    }
    //Only matrices that pass the property checks are balanced
    int codeE = SinkhornProperties(rowE, colE, e);
    assert(codeE > 0);
    int sinkhornIterations = 4000;
    double sinkhornTolerance = 1e-10;
    double sinkhornError = SinkhornKnoppAlgorithm(rowE, colE, sinkhornIterations, sinkhornTolerance, e, eCopy, rowScaling, colScaling);
    printf("\nSinkhorn Error: %.3f\n", sinkhornError);
    PrintMatrix(rowE, colE, e);
    PrintMatrix(rowE, colE, eCopy);
    free(e);
    free(eCopy);
    free(rowScaling);
    free(colScaling);
}
//Program entry point: run the property-check tests, then the balancing demo.
int main()
{
    TestSinkhornProperties();
    TestSinkhornBalance();
    return EXIT_SUCCESS;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment