Skip to content

Instantly share code, notes, and snippets.

@basp1
Last active August 15, 2017 06:03
Show Gist options
  • Save basp1/cd9c407f167496175cae641d776bb047 to your computer and use it in GitHub Desktop.
Save basp1/cd9c407f167496175cae641d776bb047 to your computer and use it in GitHub Desktop.
Simple sparse matrix class
using System;
using System.Collections.Generic;
using System.Linq;
using System.Diagnostics;
namespace Matlib
{
/// <summary>
/// A row-oriented sparse matrix in which each row is a singly linked list of cells.
/// </summary>
/// <remarks>
/// Every cell lives in the parallel lists <c>data</c>, <c>columns</c> and <c>next</c> and is
/// identified by its index ("slot") in those lists. <c>rows[i]</c> and <c>last[i]</c> are the
/// first and last slot of row <c>i</c> (<c>NIL</c> when the row is empty) and <c>count[i]</c> is
/// its length. Slots released by <see cref="RemoveRow"/> are chained through <c>next</c> into a
/// free list headed by <c>ie</c> and are reused by <c>Push</c>/<c>PushBack</c>.
/// Cells with the same column in one row are stored separately; nothing merges them.
/// </remarks>
public class SparseMatrix<T>
{
    /// <summary>
    /// Symmetry tag exposed through <see cref="Type"/>. Nothing in this class assigns a value other
    /// than the default (<c>UNSYMMENTRIC</c>). The (misspelled) names are public API and kept as-is.
    /// </summary>
    public enum SparseMatrixType
    {
        UNSYMMENTRIC, UPPER_SYMMENTRIC, LOWER_SYMMENTRIC
    }

    // Sentinel slot index: end of a chain, empty row, or empty free list.
    const int NIL = -1;

    readonly SparseMatrixType type;
    // First slot of each row, or NIL for an empty row.
    readonly List<int> rows;
    // Cell values, indexed by slot.
    readonly List<T> data;
    // Cell column indices, indexed by slot.
    readonly List<int> columns;
    // Next slot in the same row (or in the free list), or NIL.
    readonly List<int> next;
    // Last slot of each row, or NIL; lets PushBack and MergeRows append in O(1).
    readonly List<int> last;
    // Number of cells in each row.
    readonly List<int> count;
    // Head of the free list of slots released by RemoveRow, or NIL.
    int ie;
    // Number of live cells (free-list slots are not counted).
    int nnz;
    readonly int rowCount;
    readonly int columnCount;

    /// <summary>
    /// Creates an independent copy of <paramref name="other"/>'s storage, including its free list.
    /// Cell values are copied by assignment, so reference-type values are shared.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
    public SparseMatrix(SparseMatrix<T> other)
    {
        if (other == null)
        {
            throw new ArgumentNullException("other");
        }
        this.type = other.type;
        this.rows = new List<int>(other.rows);
        this.data = new List<T>(other.data);
        this.columns = new List<int>(other.columns);
        this.next = new List<int>(other.next);
        this.last = new List<int>(other.last);
        this.count = new List<int>(other.count);
        this.ie = other.ie;
        this.nnz = other.nnz;
        this.rowCount = other.rowCount;
        this.columnCount = other.columnCount;
    }

    /// <summary>Creates an empty matrix with <paramref name="rowCount"/> rows.</summary>
    /// <param name="rowCount">Number of rows; must be positive.</param>
    /// <param name="columnCount">
    /// Column bound used by the debug asserts; defaults to <see cref="int.MaxValue"/>
    /// (effectively unbounded, but then <see cref="Transpose"/> cannot be used).
    /// </param>
    public SparseMatrix(int rowCount, int columnCount = int.MaxValue)
    {
        Debug.Assert(rowCount > 0);
        Debug.Assert(columnCount > 0);
        this.rows = new List<int>(rowCount);
        this.last = new List<int>(rowCount);
        this.count = new List<int>(rowCount);
        for (int i = 0; i < rowCount; i++)
        {
            last.Add(NIL);
            rows.Add(NIL);
            count.Add(0);
        }
        // Initial capacity guess for cell storage: about one cell per row.
        this.data = new List<T>(rowCount);
        this.columns = new List<int>(rowCount);
        this.next = new List<int>(rowCount);
        this.ie = NIL;
        this.nnz = 0;
        this.rowCount = rowCount;
        this.columnCount = columnCount;
    }

    /// <summary>
    /// Returns a <paramref name="size"/> x <paramref name="size"/> identity matrix of doubles.
    /// This is declared on the generic type, so it is reached through a constructed type,
    /// e.g. <c>SparseMatrix&lt;double&gt;.Eye(n)</c>.
    /// </summary>
    public static SparseMatrix<double> Eye(int size)
    {
        var spm = new SparseMatrix<double>(size, size);
        for (int i = 0; i < size; i++)
        {
            spm.PushBack(i, i, 1.0);
        }
        return spm;
    }

    public SparseMatrixType Type { get { return type; } }
    public int RowCount { get { return rowCount; } }
    public int ColumnCount { get { return columnCount; } }

    /// <summary>Number of stored cells across all rows.</summary>
    public int Nnz { get { return nnz; } }

    /// <summary>
    /// Structural equality: same type tag and dimensions, and every row holds the same
    /// (column, value) cells in the same order. Values are compared with
    /// <see cref="EqualityComparer{T}.Default"/>, so null values are allowed.
    /// Returns false (rather than throwing) when <paramref name="obj"/> is null or not a matrix.
    /// </summary>
    public override bool Equals(object obj)
    {
        var other = obj as SparseMatrix<T>;
        if (other == null)
        {
            return false;
        }
        if (ReferenceEquals(this, other))
        {
            return true;
        }
        if (type != other.type
            || Nnz != other.Nnz
            || RowCount != other.RowCount
            || ColumnCount != other.ColumnCount)
        {
            return false;
        }
        // Equal per-row counts guarantee both chains below end at the same step.
        if (!count.SequenceEqual(other.count))
        {
            return false;
        }
        var comparer = EqualityComparer<T>.Default;
        for (int i = 0; i < RowCount; i++)
        {
            for (int j = rows[i], other_j = other.rows[i]; NIL != j && NIL != other_j; j = next[j], other_j = other.next[other_j])
            {
                if (columns[j] != other.columns[other_j] || !comparer.Equals(data[j], other.data[other_j]))
                {
                    return false;
                }
            }
        }
        return true;
    }

    /// <summary>
    /// Hash consistent with <see cref="Equals(object)"/>. The matrix is mutable, so the hash
    /// changes whenever its contents change.
    /// </summary>
    public override int GetHashCode()
    {
        var comparer = EqualityComparer<T>.Default;
        unchecked
        {
            int hash = 17;
            hash = hash * 31 + rowCount;
            hash = hash * 31 + columnCount;
            hash = hash * 31 + nnz;
            for (int i = 0; i < rowCount; i++)
            {
                for (int j = rows[i]; NIL != j; j = next[j])
                {
                    hash = hash * 31 + i;
                    hash = hash * 31 + columns[j];
                    hash = hash * 31 + (data[j] == null ? 0 : comparer.GetHashCode(data[j]));
                }
            }
            return hash;
        }
    }

    /// <summary>Number of cells stored in <paramref name="row"/>.</summary>
    public int Count(int row)
    {
        Debug.Assert(0 <= row && row < rowCount);
        return count[row];
    }

    /// <summary>Removes every cell and discards all slot storage, keeping the dimensions.</summary>
    public void Clear()
    {
        this.ie = NIL;
        this.nnz = 0;
        for (int i = 0; i < rowCount; i++)
        {
            this.rows[i] = NIL;
            this.count[i] = 0;
            this.last[i] = NIL;
        }
        this.data.Clear();
        this.columns.Clear();
        this.next.Clear();
    }

    /// <summary>Returns the first cell of <paramref name="row"/>; the row must have at least one cell.</summary>
    public Cell<T> First(int row)
    {
        Debug.Assert(Count(row) >= 1);
        int j = rows[row];
        return new Cell<T> { Column = columns[j], Value = data[j] };
    }

    /// <summary>Returns the second cell of <paramref name="row"/>; the row must have at least two cells.</summary>
    public Cell<T> Second(int row)
    {
        Debug.Assert(Count(row) >= 2);
        int j = next[rows[row]];
        return new Cell<T> { Column = columns[j], Value = data[j] };
    }

    /// <summary>
    /// Replaces the contents of <paramref name="list"/> with the cells of <paramref name="row"/>,
    /// in chain order. A null <paramref name="list"/> is replaced with a new list. An existing
    /// list keeps its capacity when that capacity is already large enough, so reusing one list
    /// across rows does not reallocate every time.
    /// </summary>
    public void CopyRow(int row, ref List<Cell<T>> list)
    {
        Debug.Assert(0 <= row && row < rowCount);
        int n = Count(row);
        if (list == null)
        {
            list = new List<Cell<T>>(n);
        }
        else
        {
            list.Clear();
            if (list.Capacity < n)
            {
                list.Capacity = n;
            }
        }
        for (int j = rows[row]; NIL != j; j = next[j])
        {
            list.Add(new Cell<T> { Column = columns[j], Value = data[j] });
        }
    }

    /// <summary>Returns a new list with the cells of <paramref name="row"/>, in chain order.</summary>
    public List<Cell<T>> CopyRow(int row)
    {
        var list = new List<Cell<T>>(Count(row));
        CopyRow(row, ref list);
        return list;
    }

    /// <summary>True if any cell of <paramref name="row"/> satisfies <paramref name="predicate"/>.</summary>
    public bool Contains(int row, Predicate<Cell<T>> predicate)
    {
        Debug.Assert(0 <= row && row < rowCount);
        for (int j = rows[row]; NIL != j; j = next[j])
        {
            if (predicate(new Cell<T> { Column = columns[j], Value = data[j] }))
            {
                return true;
            }
        }
        return false;
    }

    /// <summary>Prepends <paramref name="cell"/> to <paramref name="row"/>.</summary>
    public void Push(int row, Cell<T> cell)
    {
        Push(row, cell.Column, cell.Value);
    }

    /// <summary>
    /// Prepends a cell to <paramref name="row"/> in O(1), reusing a freed slot when one is available.
    /// </summary>
    public void Push(int row, int column, T value)
    {
        Debug.Assert(0 <= row && row < rowCount);
        Debug.Assert(0 <= column && column < columnCount);
        int j = AllocateSlot(column, value);
        if (NIL == rows[row])
        {
            // The first cell of an empty row is also its last.
            last[row] = j;
        }
        next[j] = rows[row];
        rows[row] = j;
        count[row] += 1;
        nnz += 1;
    }

    /// <summary>Appends <paramref name="cell"/> to <paramref name="row"/>.</summary>
    public void PushBack(int row, Cell<T> cell)
    {
        PushBack(row, cell.Column, cell.Value);
    }

    /// <summary>
    /// Appends a cell to <paramref name="row"/> in O(1), reusing a freed slot when one is available.
    /// </summary>
    public void PushBack(int row, int column, T value)
    {
        Debug.Assert(0 <= row && row < rowCount);
        Debug.Assert(0 <= column && column < columnCount);
        int j = AllocateSlot(column, value);
        next[j] = NIL;
        if (NIL == rows[row])
        {
            rows[row] = j;
        }
        else
        {
            next[last[row]] = j;
        }
        last[row] = j;
        count[row] += 1;
        nnz += 1;
    }

    // Stores (column, value) in a slot taken from the free list, or in a newly appended slot when
    // the free list is empty, and returns the slot index. The caller links it into a row via next[].
    int AllocateSlot(int column, T value)
    {
        if (NIL != ie)
        {
            int slot = ie;
            ie = next[slot];
            data[slot] = value;
            columns[slot] = column;
            return slot;
        }
        data.Add(value);
        columns.Add(column);
        next.Add(NIL);
        return data.Count - 1;
    }

    /// <summary>
    /// Removes every cell of <paramref name="row"/> and moves its slots onto the free list.
    /// </summary>
    public void RemoveRow(int row)
    {
        Debug.Assert(0 <= row && row < rowCount);
        if (0 == Count(row))
        {
            return;
        }
        // Reset released values so the free list does not keep reference-type values reachable.
        for (int j = rows[row]; NIL != j; j = next[j])
        {
            data[j] = default(T);
        }
        // Splice the whole row chain onto the front of the free list in O(1).
        next[last[row]] = ie;
        ie = rows[row];
        nnz -= count[row];
        rows[row] = NIL;
        last[row] = NIL;
        count[row] = 0;
    }

    /// <summary>
    /// Moves all cells of <paramref name="srcRow"/> to the end of <paramref name="dstRow"/> in O(1),
    /// leaving <paramref name="srcRow"/> empty. Merging a row into itself does nothing.
    /// </summary>
    public void MergeRows(int dstRow, int srcRow)
    {
        Debug.Assert(0 <= dstRow && dstRow < rowCount);
        Debug.Assert(0 <= srcRow && srcRow < rowCount);
        // Without the self check, a non-empty row would be linked back onto its own head (a cycle)
        // and then cleared, losing its cells while nnz still counted them.
        if (dstRow == srcRow || 0 == Count(srcRow))
        {
            return;
        }
        if (0 == Count(dstRow))
        {
            rows[dstRow] = rows[srcRow];
        }
        else
        {
            next[last[dstRow]] = rows[srcRow];
        }
        last[dstRow] = last[srcRow];
        count[dstRow] += count[srcRow];
        rows[srcRow] = NIL;
        last[srcRow] = NIL;
        count[srcRow] = 0;
    }

    /// <summary>
    /// Returns a new columnCount x rowCount matrix with every cell (i, c) moved to (c, i).
    /// Each result row lists its cells in increasing source-row order. The column count must
    /// have been given explicitly (not the <see cref="int.MaxValue"/> default), because it
    /// becomes the row count of the result.
    /// </summary>
    public SparseMatrix<T> Transpose()
    {
        Debug.Assert(int.MaxValue != columnCount);
        var t = new SparseMatrix<T>(columnCount, rowCount);
        for (int i = 0; i < rowCount; i++)
        {
            for (int j = rows[i]; NIL != j; j = next[j])
            {
                t.PushBack(columns[j], i, data[j]);
            }
        }
        return t;
    }

    /// <summary>
    /// Sorts the cells of every row by column. The sort is stable, so cells that share a column
    /// keep their relative order. Each row keeps its slots; only their contents are rewritten.
    /// </summary>
    public void SortRows()
    {
        var cells = new List<Cell<T>>();
        for (int i = 0; i < rowCount; i++)
        {
            if (Count(i) < 2)
            {
                continue; // empty and single-cell rows are already sorted
            }
            CopyRow(i, ref cells);
            // Enumerable.OrderBy is stable; write the sorted cells back along the existing chain.
            int j = rows[i];
            foreach (var cell in cells.OrderBy(c => c.Column))
            {
                columns[j] = cell.Column;
                data[j] = cell.Value;
                j = next[j];
            }
        }
    }

    /// <summary>A (column, value) pair as returned by the row accessors.</summary>
    public struct Cell<TValue>
    {
        public int Column;
        public TValue Value;
    }
}
}
using System;
using Matlib;
using Microsoft.VisualStudio.TestTools.UnitTesting;
namespace Matlib.Tests
{
/// <summary>
/// Unit tests for <see cref="SparseMatrix{T}"/>: insertion order of Push/PushBack, row removal,
/// merging, sorting, copying and transposition.
/// </summary>
[TestClass]
public class SparseMatrixTests
{
    /// <summary>Push prepends, so a row reads back in reverse insertion order.</summary>
    [TestMethod]
    public void SparseMatrixPush()
    {
        var spm = new SparseMatrix<char>(3);
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(0, spm.Nnz);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'c');
        spm.Push(1, 0, 'd');
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);
        var list = spm.CopyRow(0);
        Assert.AreEqual(0, list.Count);
        list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('d', list[0].Value);
        Assert.AreEqual('c', list[1].Value);
        Assert.AreEqual('b', list[2].Value);
        Assert.AreEqual('a', list[3].Value);
    }

    /// <summary>PushBack appends, so a row reads back in insertion order.</summary>
    [TestMethod]
    public void SparseMatrixPushBack()
    {
        var spm = new SparseMatrix<char>(3);
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(0, spm.Nnz);
        spm.PushBack(1, 0, 'a');
        spm.PushBack(1, 0, 'b');
        spm.PushBack(1, 0, 'c');
        spm.PushBack(1, 0, 'd');
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);
        var list = spm.CopyRow(0);
        Assert.AreEqual(0, list.Count);
        list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('a', list[0].Value);
        Assert.AreEqual('b', list[1].Value);
        Assert.AreEqual('c', list[2].Value);
        Assert.AreEqual('d', list[3].Value);
    }

    [TestMethod]
    public void SparseMatrixContains()
    {
        var spm = new SparseMatrix<char>(3);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');
        Assert.IsTrue(spm.Contains(1, cell => 'a' == cell.Value));
        Assert.IsTrue(spm.Contains(1, cell => 'b' == cell.Value));
        Assert.IsFalse(spm.Contains(1, cell => 'c' == cell.Value));
    }

    /// <summary>Removing rows updates the counts, and freed slots can be refilled afterwards.</summary>
    [TestMethod]
    public void SparseMatrixRemoveRow()
    {
        var spm = new SparseMatrix<char>(3);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');
        spm.RemoveRow(0);
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);
        spm.RemoveRow(1);
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(0, spm.Nnz);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');
        var list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('a', list[0].Value);
        Assert.AreEqual('b', list[1].Value);
        Assert.AreEqual('b', list[2].Value);
        Assert.AreEqual('a', list[3].Value);
    }

    [TestMethod]
    public void SparseMatrixClear()
    {
        var spm = new SparseMatrix<char>(3);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');
        spm.Clear();
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(0, spm.Nnz);
    }

    /// <summary>Merging a row built with Push into an empty row moves every cell.</summary>
    [TestMethod]
    public void SparseMatrixMergeRows()
    {
        var spm = new SparseMatrix<char>(3);
        spm.Push(1, 0, 'a');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'b');
        spm.Push(1, 0, 'a');
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);
        spm.MergeRows(0, 1);
        Assert.AreEqual(4, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);
        var list = spm.CopyRow(0);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('a', list[0].Value);
        Assert.AreEqual('b', list[1].Value);
        Assert.AreEqual('b', list[2].Value);
        Assert.AreEqual('a', list[3].Value);
    }

    /// <summary>Merging a row built with PushBack into an empty row moves every cell.</summary>
    [TestMethod]
    public void SparseMatrixMergeRows2()
    {
        var spm = new SparseMatrix<char>(3);
        spm.PushBack(1, 0, 'a');
        spm.PushBack(1, 0, 'b');
        spm.PushBack(1, 0, 'b');
        spm.PushBack(1, 0, 'a');
        Assert.AreEqual(0, spm.Count(0));
        Assert.AreEqual(4, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);
        spm.MergeRows(0, 1);
        Assert.AreEqual(4, spm.Count(0));
        Assert.AreEqual(0, spm.Count(1));
        Assert.AreEqual(0, spm.Count(2));
        Assert.AreEqual(4, spm.Nnz);
        var list = spm.CopyRow(0);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual('a', list[0].Value);
        Assert.AreEqual('b', list[1].Value);
        Assert.AreEqual('b', list[2].Value);
        Assert.AreEqual('a', list[3].Value);
    }

    /// <summary>SortRows orders a row's cells by increasing column.</summary>
    [TestMethod]
    public void SparseMatrixSortRows()
    {
        var spm = new SparseMatrix<int>(3);
        spm.Push(1, 3, 3);
        spm.Push(1, 1, 1);
        spm.Push(1, 0, 0);
        spm.Push(1, 2, 2);
        var list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual(2, list[0].Column);
        Assert.AreEqual(0, list[1].Column);
        Assert.AreEqual(1, list[2].Column);
        Assert.AreEqual(3, list[3].Column);
        spm.SortRows();
        list = spm.CopyRow(1);
        Assert.AreEqual(4, list.Count);
        Assert.AreEqual(0, list[0].Column);
        Assert.AreEqual(1, list[1].Column);
        Assert.AreEqual(2, list[2].Column);
        Assert.AreEqual(3, list[3].Column);
    }

    /// <summary>A copy is equal to its source but does not share storage with it.</summary>
    [TestMethod]
    public void SparseMatrixCopyConstructor()
    {
        var spm = new SparseMatrix<char>(3);
        spm.PushBack(1, 0, 'a');
        spm.PushBack(1, 1, 'b');
        var copy = new SparseMatrix<char>(spm);
        Assert.IsTrue(spm.Equals(copy));
        copy.RemoveRow(1);
        Assert.AreEqual(2, spm.Count(1));
        Assert.AreEqual(0, copy.Count(1));
        Assert.IsFalse(spm.Equals(copy));
    }

    /// <summary>Transposing twice returns a matrix equal to the original.</summary>
    [TestMethod]
    public void SparseMatrixTranspose()
    {
        var m = new SparseMatrix<int>(2, 3);
        m.PushBack(0, 0, 0); m.PushBack(0, 1, 1); m.PushBack(0, 2, 2);
        m.PushBack(1, 0, 3); m.PushBack(1, 1, 4); m.PushBack(1, 2, 5);
        var t = m.Transpose();
        var tt = t.Transpose();
        var ttt = tt.Transpose();
        Assert.IsFalse(m.Equals(t));
        Assert.IsTrue(m.Equals(tt));
        Assert.IsTrue(t.Equals(ttt));
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment