Skip to content

Instantly share code, notes, and snippets.

@tscholl2
Created June 24, 2020 23:03
Show Gist options
  • Save tscholl2/c5ea6b3ea6a68d35616d3f8f68b77540 to your computer and use it in GitHub Desktop.
Save tscholl2/c5ea6b3ea6a68d35616d3f8f68b77540 to your computer and use it in GitHub Desktop.
#include "linear-algebra.h"
/* Entry (i, j) of a row-major matrix; expects the matrix pointer `A` and its
 * column count `m` to be in scope where the macro is used. */
#define mat_val(i, j) (A[(i) * m + (j)])

/*
 * Every Element_* use below is wrapped in its own parentheses and receives
 * only simple or parenthesized arguments, so the arithmetic keeps its
 * intended precedence even if the Element_* macros expand without parentheses.
 *
 * NOTE(review): pivots are the first nonzero entry found (no partial
 * pivoting), since the Element interface offers no magnitude comparison.
 * With floating-point Elements this can be numerically fragile.
 */

/* Exchange rows r1 and r2 of the row-major matrix A (m columns) together
 * with the matching right-hand-side entries of b. */
static void swap_equations(Element *A, Element *b, int m, int r1, int r2)
{
    for (int c = 0; c < m; c++)
    {
        Element tmp = mat_val(r1, c);
        mat_val(r1, c) = mat_val(r2, c);
        mat_val(r2, c) = tmp;
    }
    Element rhs = b[r1];
    b[r1] = b[r2];
    b[r2] = rhs;
}

/* Reduce [A | b] in place to row echelon form by Gaussian elimination and
 * return the number of pivot rows; rows rank..n-1 of A end up all zero. */
static int to_row_echelon(Element *A, Element *b, int n, int m)
{
    int rank = 0;
    for (int col = 0; col < m && rank < n; col++)
    {
        /* First row at or below `rank` with a nonzero entry in this column. */
        int pivot = rank;
        while (pivot < n && (Element_is_zero(mat_val(pivot, col))))
        {
            pivot++;
        }
        if (pivot == n)
        {
            continue; /* no pivot here: this column is a free variable */
        }
        /* The pivot goes to row `rank`, never to row `col` (which may not exist). */
        if (pivot != rank)
        {
            swap_equations(A, b, m, pivot, rank);
        }
        for (int r = rank + 1; r < n; r++)
        {
            if ((Element_is_zero(mat_val(r, col))))
            {
                continue;
            }
            Element ratio = (Element_div(mat_val(r, col), mat_val(rank, col)));
            /* Store an exact zero rather than a - (a / p) * p, which can round
             * to a tiny nonzero value and later be mistaken for a pivot. */
            mat_val(r, col) = Element_0;
            for (int c = col + 1; c < m; c++)
            {
                mat_val(r, c) = (Element_sub(mat_val(r, c),
                                             (Element_mul(ratio, mat_val(rank, c)))));
            }
            b[r] = (Element_sub(b[r], (Element_mul(ratio, b[rank]))));
        }
        rank++;
    }
    return rank;
}

/* Back-substitute through the first `rank` rows of an echelon system,
 * solving each pivot variable from its row and setting every free variable
 * to Element_0. Writes all m entries of x. */
static void back_substitute(const Element *A, const Element *b, int rank, int m,
                            Element *x)
{
    for (int c = 0; c < m; c++)
    {
        x[c] = Element_0;
    }
    for (int r = rank - 1; r >= 0; r--)
    {
        /* In echelon form the pivot of row r is its first nonzero entry. */
        int pc = 0;
        while (pc < m && (Element_is_zero(mat_val(r, pc))))
        {
            pc++;
        }
        if (pc == m)
        {
            continue; /* not reachable for a pivot row; guards the reads below */
        }
        Element acc = b[r];
        for (int c = pc + 1; c < m; c++)
        {
            acc = (Element_sub(acc, (Element_mul(mat_val(r, c), x[c]))));
        }
        x[pc] = (Element_div(acc, mat_val(r, pc)));
    }
}

/*
 * Solve A x = b for an n-by-m row-major matrix A (n equations, m unknowns).
 *
 * A and b are overwritten with their row echelon form. x (m entries) is
 * always written, with every free variable set to Element_0; it holds a
 * solution when 1 is returned.
 *
 * Returns 1 if the system is consistent and 0 otherwise.
 */
int solve(Element *A, Element *b, int n, int m, Element *x)
{
    int rank = to_row_echelon(A, b, n, m);
    back_substitute(A, b, rank, m, x);
    /* Pivot rows hold by construction of back substitution, so the system is
     * inconsistent exactly when a zero row of A has a nonzero b entry.
     * Re-checking pivot rows with exact Element_eq would additionally reject
     * valid floating-point solutions over last-bit rounding differences. */
    for (int r = rank; r < n; r++)
    {
        if (!(Element_is_zero(b[r])))
        {
            return 0;
        }
    }
    return 1;
}
#ifndef LINEAR_ALGEBRA_HEADER_H
#define LINEAR_ALGEBRA_HEADER_H

/*
 * Scalar type used by solve() and the scalar operations it performs.
 * Every function-like macro parenthesizes its arguments and its whole
 * expansion, so nested uses such as Element_mul(x, c ? p : q),
 * k * Element_add(a, b) or !Element_eq(a, b) keep the intended precedence.
 */
#define Element double
#define Element_0 0.0
#define Element_1 1.0
/* NOTE(review): exact == on doubles; rounding can make mathematically equal
 * values compare unequal. A tolerance-based comparison may be preferable. */
#define Element_eq(a, b) ((a) == (b))
#define Element_is_zero(a) ((a) == Element_0)
#define Element_add(a, b) ((a) + (b))
#define Element_sub(a, b) ((a) - (b))
#define Element_mul(a, b) ((a) * (b))
#define Element_div(a, b) ((a) / (b))

/**
 * Given an n-by-m matrix A and an n-entry vector b, sets x to a solution of
 * A x = b if one exists.
 *
 * @param A  n-by-m matrix stored row-major (entry (i, j) at A[i * m + j]);
 *           modified in place.
 * @param b  right-hand side with n entries; modified in place.
 * @param n  number of rows (equations).
 * @param m  number of columns (unknowns).
 * @param x  output vector with m entries.
 * @return   1 if a solution exists, 0 otherwise.
 **/
int solve(Element *A, Element *b, int n, int m, Element *x);

#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment