Skip to content

Instantly share code, notes, and snippets.

@kurtbrose
Last active December 13, 2018 23:44
Show Gist options
  • Select an option

  • Save kurtbrose/fad1baf1b4cff07f9295fb2377d76663 to your computer and use it in GitHub Desktop.

Select an option

Save kurtbrose/fad1baf1b4cff07f9295fb2377d76663 to your computer and use it in GitHub Desktop.
Matrices in an SQLite database (mostly as a prototype for linear algebra in larger SQL databases).
'''
implementation of some linear algebra in SQL

Sparse matrices are stored in an in-memory SQLite database (MatrixDB),
manipulated through Matrix handles, and used by DiGraphTC to maintain the
transitive closure of a directed graph.
'''
import sqlite3
import time
import weakref
# sentinel: lets MatrixDB.sql_val tell "no default given" apart from default=None
_MISSING = object()
class MatrixDB(object):
    '''
    Serves two purposes --
    1 - stores underlying data (an in-memory SQLite database built from SCHEMA)
    2 - serves as a factory for Matrix instances
    matrices are stored as sparse -- missing elements are implicit 0

    Methods that take a ``name`` accept either a matrix's integer id (its
    ROWID in the ``matrix`` table) or its name; see ``_name2id``.
    '''

    def __init__(self):
        self.db = sqlite3.connect(':memory:')
        # SCHEMA is a series of ';'-separated DDL statements; run them one by one
        for statement in SCHEMA.split(';'):
            statement = statement.strip()
            if statement:
                self.db.execute(statement)

    def execute(self, sql):
        '''pass-through: run sql on the connection and return the cursor'''
        return self.db.execute(sql)

    def sql(self, sql, args=None):
        '''run sql against underlying DB, fetch and return all result rows'''
        if args is None:
            return self.db.execute(sql).fetchall()
        return self.db.execute(sql, args).fetchall()

    def sql_val(self, sql, args=None, default=_MISSING):
        '''
        run SELECT sql that returns a single value

        Returns the first column of the first row.  When there are no rows,
        returns ``default`` if one was given, otherwise raises ValueError.
        '''
        result = self.sql(sql, args)
        if not result:
            if default is not _MISSING:
                return default
            raise ValueError("sql returned no rows", sql, args)
        return result[0][0]

    def sql_col(self, sql, args=None):
        '''run SELECT and return [a, b, c] instead of [(a,), (b,), (c,)]'''
        # flatten with a comprehension; sum(rows, ()) is quadratic in the row count
        return [val for row in self.sql(sql, args) for val in row]

    def sql_row(self, sql, args=None):
        '''run SELECT and return the first row as a tuple (a, b, c); IndexError if no rows'''
        return self.sql(sql, args)[0]

    def make_zero(self, name, nrows, ncols):
        '''create a zero matrix and return it as a Matrix'''
        return Matrix(self, self._make_zero(name, nrows, ncols))

    def _make_zero(self, name, nrows, ncols):
        '''insert the header row of a new (all-zero) matrix and return its id'''
        cursor = self.db.execute(
            'INSERT INTO matrix (name, nrows, ncols) VALUES (?, ?, ?)',
            (name, nrows, ncols))
        # ROWID of the row just inserted; no second query needed
        return cursor.lastrowid

    def make_identity(self, name, size):
        '''create a size x size identity matrix and return it as a Matrix'''
        matrix_id = self._make_zero(name, size, size)
        self.db.executemany(
            'INSERT INTO element (matrix_id, row, col, val) VALUES (?, ?, ?, 1)',
            [(matrix_id, i, i) for i in range(size)])
        return Matrix(self, matrix_id)

    def rows2matrix(self, name, rowlist):
        '''
        create a Matrix from a list of equal-length rows

        Raises ValueError for an empty list or for rows of differing length.
        '''
        if not rowlist:
            raise ValueError('cannot construct 0x0 matrix')
        if len(set(len(row) for row in rowlist)) != 1:
            raise ValueError('all rows must have the same length')
        nrows = len(rowlist)
        ncols = len(rowlist[0])
        matrix_id = self._make_zero(name, nrows, ncols)
        # storage is sparse: only nonzero entries get a row.  The indices come
        # from the list itself, so no per-element bounds check is needed.
        self.db.executemany(
            'INSERT INTO element (matrix_id, row, col, val) VALUES (?, ?, ?, ?)',
            [(matrix_id, i, j, val)
             for i, row in enumerate(rowlist)
             for j, val in enumerate(row)
             if val != 0])
        return Matrix(self, matrix_id)

    def _name2id(self, name):
        '''
        resolve a matrix reference to its integer id

        Integers are returned unchanged.  Anything else is looked up by name;
        names are not unique (mult/add generate names like "(1+2)"), so the
        most recently created matrix with that name wins.  Raises ValueError
        if no matrix has the name.
        '''
        import numbers  # covers int (and long on Python 2)
        if isinstance(name, numbers.Integral) and not isinstance(name, bool):
            return int(name)
        return self.sql_val(
            'SELECT ROWID FROM matrix WHERE name = ? ORDER BY ROWID DESC LIMIT 1',
            (name,))

    def matrix_dims(self, name):
        '''return (nrows, ncols)'''
        matrix_id = self._name2id(name)
        return self.sql_row('SELECT nrows, ncols FROM matrix WHERE ROWID = ?', (matrix_id,))

    def matrix_cost(self, name):
        '''return the number of rows in the database used to represent the matrix'''
        matrix_id = self._name2id(name)
        return self.sql_val('SELECT count(*) FROM element WHERE matrix_id = ?', (matrix_id,))

    def _clean(self, matrix_id):
        '''drop explicitly stored zeros so the representation stays sparse'''
        self.sql('DELETE FROM element WHERE matrix_id = ? AND val = 0', (matrix_id,))

    def mult(self, lhs_name, rhs_name, result_name=None):
        '''
        multiply two matrices together, storing the product as a new matrix

        Returns the name of the result (default "lhs*rhs").
        Raises ValueError unless LHS cols = RHS rows.
        '''
        if result_name is None:
            result_name = '{}*{}'.format(lhs_name, rhs_name)
        lhs_id = self._name2id(lhs_name)
        rhs_id = self._name2id(rhs_name)
        lhs_rows, lhs_cols = self.matrix_dims(lhs_id)
        rhs_rows, rhs_cols = self.matrix_dims(rhs_id)
        if lhs_cols != rhs_rows:
            raise ValueError('cannot multiply matrices unless LHS cols = RHS rows')
        result_id = self._make_zero(result_name, lhs_rows, rhs_cols)
        # result[i, k] = sum over j of lhs[i, j] * rhs[j, k].  Only nonzero
        # elements are stored, so the join on lhs.col = rhs.row yields exactly
        # the nonzero products; sum() folds them per output cell.
        self.sql(
            '''
            INSERT INTO element (matrix_id, val, row, col)
            SELECT ?, sum(lhs.val * rhs.val), lhs.row, rhs.col
            FROM element AS lhs JOIN element AS rhs ON (lhs.col = rhs.row)
            WHERE lhs.matrix_id = ? AND rhs.matrix_id = ?
            GROUP BY lhs.row, rhs.col''',
            (result_id, lhs_id, rhs_id))
        self._clean(result_id)  # products that cancelled out to 0
        return result_name

    def add(self, lhs_name, rhs_name, result_name=None):
        '''
        add two matrices, storing the sum as a new matrix

        Returns the name of the result (default "(lhs+rhs)").
        Raises ValueError if the sizes differ.
        '''
        if result_name is None:
            result_name = "({}+{})".format(lhs_name, rhs_name)
        lhs_id = self._name2id(lhs_name)
        rhs_id = self._name2id(rhs_name)
        dims = self.matrix_dims(lhs_id)
        if dims != self.matrix_dims(rhs_id):
            raise ValueError('cannot add matrices of different sizes')
        result_id = self._make_zero(result_name, *dims)
        # UNION ALL (rather than matrix_id IN (lhs, rhs)) so that A + A counts
        # every element twice; missing elements are implicit 0 and simply absent
        self.sql(
            '''
            INSERT INTO element (matrix_id, val, row, col)
            SELECT ?, sum(val), row, col
            FROM (
                SELECT val, row, col FROM element WHERE matrix_id = ?
                UNION ALL
                SELECT val, row, col FROM element WHERE matrix_id = ?)
            GROUP BY row, col''',
            (result_id, lhs_id, rhs_id))
        self._clean(result_id)
        return result_name

    def accumulate(self, base, incr):
        '''inline add -- base += incr; raises ValueError if the sizes differ'''
        base = self._name2id(base)
        incr = self._name2id(incr)
        if self.matrix_dims(base) != self.matrix_dims(incr):
            raise ValueError('cannot add matrices of different sizes')
        self.sql(  # add the values where indices line up
            '''
            UPDATE element SET val = val + (
                SELECT val FROM element as incr
                WHERE incr.matrix_id = :incr AND element.row = incr.row AND element.col = incr.col)
            WHERE element.matrix_id = :base AND EXISTS (
                SELECT 1 FROM element as incr WHERE
                incr.matrix_id = :incr AND element.row = incr.row AND element.col = incr.col)
            ''',
            dict(base=base, incr=incr))
        self.sql(  # set any values which are not present (implicit 0's)
            '''
            INSERT INTO element (matrix_id, val, row, col)
            SELECT :base, val, row, col FROM element as incr WHERE
            incr.matrix_id = :incr AND NOT EXISTS (
                SELECT 1 FROM element as base
                WHERE base.matrix_id = :base AND base.row = incr.row AND base.col = incr.col)''',
            dict(base=base, incr=incr))
        self._clean(base)  # remove any rows that have become 0

    def _check_key(self, matrix_id, row, col):
        '''raise KeyError unless 0 <= row < nrows and 0 <= col < ncols'''
        nrows, ncols = self.matrix_dims(matrix_id)
        if not (0 <= row < nrows and 0 <= col < ncols):
            raise KeyError('out of size of matrix')

    def set_val(self, name, val, row, col):
        '''set a value in a matrix; setting 0 removes the stored element'''
        matrix_id = self._name2id(name)
        self._check_key(matrix_id, row, col)
        if val == 0:
            self.sql(
                'DELETE FROM element WHERE matrix_id = ? AND row = ? AND col = ?',
                (matrix_id, row, col))
        else:
            self.sql(
                'INSERT OR REPLACE INTO element (val, matrix_id, row, col) VALUES (?, ?, ?, ?)',
                (val, matrix_id, row, col))

    def get_val(self, name, row, col):
        '''get one value from a matrix (0 for elements that are not stored)'''
        matrix_id = self._name2id(name)
        self._check_key(matrix_id, row, col)
        return self.sql_val(
            'SELECT val FROM element WHERE matrix_id = ? and row = ? and col = ?',
            (matrix_id, row, col), default=0)

    def is_eq(self, lhs_name, rhs_name):
        '''check if two matrices have the same size and the same elements'''
        lhs_id = self._name2id(lhs_name)
        rhs_id = self._name2id(rhs_name)
        if self.matrix_dims(lhs_id) != self.matrix_dims(rhs_id):
            return False
        # search for any position where they differ; a missing element is 0.
        # The first EXISTS covers every stored lhs position, the second the
        # rhs positions that lhs does not store at all.
        return not self.sql_val(
            '''
            SELECT EXISTS(
                SELECT 1 FROM element AS lhs LEFT JOIN element AS rhs
                ON (rhs.matrix_id = :rhs AND rhs.row = lhs.row AND rhs.col = lhs.col)
                WHERE lhs.matrix_id = :lhs AND lhs.val != coalesce(rhs.val, 0))
            OR EXISTS(
                SELECT 1 FROM element AS rhs LEFT JOIN element AS lhs
                ON (lhs.matrix_id = :lhs AND lhs.row = rhs.row AND lhs.col = rhs.col)
                WHERE rhs.matrix_id = :rhs AND lhs.val IS NULL AND rhs.val != 0)''',
            dict(lhs=lhs_id, rhs=rhs_id))

    def dump(self, name):
        '''dump matrix to list-of-rows, filling in the implicit zeros'''
        matrix_id = self._name2id(name)
        nrows, ncols = self.matrix_dims(matrix_id)
        rows = [[0] * ncols for _ in range(nrows)]
        # one query for every stored element instead of one query per cell;
        # the bounds filter ignores anything stored outside the declared size
        for row, col, val in self.sql(
                '''SELECT row, col, val FROM element
                WHERE matrix_id = ? AND row >= 0 AND row < ? AND col >= 0 AND col < ?''',
                (matrix_id, nrows, ncols)):
            rows[row][col] = val
        return rows

    def delete(self, name):
        '''remove a matrix (header and all elements) from the database'''
        matrix_id = self._name2id(name)
        self.sql('DELETE FROM matrix WHERE ROWID = ?', (matrix_id,))
        self.sql('DELETE FROM element WHERE matrix_id = ?', (matrix_id,))

    def zero_out_matrix(self, name):
        '''set every element of a matrix to 0 (keeps the matrix and its size)'''
        matrix_id = self._name2id(name)
        self.sql('DELETE FROM element WHERE matrix_id = ?', (matrix_id,))
class _Cleaner(object):
    '''
    Callable used as a weakref callback: when invoked it asks ``db`` to
    delete the matrix whose id it was created with.
    '''

    def __init__(self, db, id):
        self.db = db
        self.id = id

    def __call__(self, ref):
        # the dead weak reference itself is not needed, only the saved id
        target_db = self.db
        target_db.delete(self.id)
class Matrix(object):
    '''
    Handle to one matrix stored in a MatrixDB.

    Arithmetic operators create new matrices in the same database.  When the
    handle is garbage collected, the weakref callback (a _Cleaner) calls
    ``db.delete`` for its id.
    '''

    def __init__(self, db, name):
        self.db, self.id = db, db._name2id(name)
        # weakref to ourselves whose callback deletes the stored matrix
        self._cleaner = weakref.ref(self, _Cleaner(db, self.id))

    def _check_compatible(self, other):
        '''operands must be Matrix handles on the same MatrixDB'''
        assert type(self) is type(other)
        assert self.db is other.db

    def __add__(self, other):
        self._check_compatible(other)
        return Matrix(self.db, self.db.add(self.id, other.id))

    def __iadd__(self, other):
        # same operand checks as + and *, then add in place
        self._check_compatible(other)
        self.db.accumulate(self.id, other.id)
        return self

    def __mul__(self, other):
        self._check_compatible(other)
        return Matrix(self.db, self.db.mult(self.id, other.id))

    def matrix_cost(self):
        '''number of element rows stored for this matrix'''
        return self.db.matrix_cost(self.id)

    def __setitem__(self, key, value):
        row, col = key
        self.db.set_val(self.id, value, row, col)

    def __getitem__(self, key):
        row, col = key
        return self.db.get_val(self.id, row, col)

    def __delitem__(self, key):
        # deleting an element means setting it to the implicit 0
        row, col = key
        self.db.set_val(self.id, 0, row, col)

    def as_rowlist(self):
        '''the full (dense) contents as a list of rows'''
        return self.db.dump(self.id)

    @property
    def name(self):
        return self.db.sql_val('SELECT name FROM matrix WHERE ROWID = ?', (self.id,))

    def __eq__(self, other):
        if not isinstance(other, Matrix):
            return NotImplemented
        # read each operand through its own db
        return self.as_rowlist() == other.as_rowlist()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # mutable with value-based equality, so not hashable (the Python 3 default)
    __hash__ = None
# DDL for MatrixDB.  It is executed one ';'-separated statement at a time, so
# no ';' may appear inside a statement.
#   matrix  -- one row per matrix; its ROWID is the matrix id
#   element -- one row per stored element (missing elements are implicit 0)
# Every query filters on matrix_id, so both indexes lead with it:
#   coords          -- unique (matrix_id, row, col): lookups by position / row
#   element_col_row -- (matrix_id, col, row): lookups of one column of a matrix
#                      (e.g. WHERE matrix_id = ? AND col = ?)
SCHEMA = '''
CREATE TABLE matrix (
    name TEXT NOT NULL,
    nrows INTEGER NOT NULL,
    ncols INTEGER NOT NULL
);
CREATE TABLE element (
    matrix_id INTEGER NOT NULL,
    row INTEGER NOT NULL,
    col INTEGER NOT NULL,
    val INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX element_col_row ON element(matrix_id, col, row);
CREATE UNIQUE INDEX coords ON element(matrix_id, row, col);
'''
def _m(db, datastring):
    '''
    Build a square matrix from a string of single digits in row-major order,
    e.g. '0100' -> [[0, 1], [0, 0]].  The string itself is used as the name.

    Raises ValueError if the length is not a perfect square (an empty string
    reaches rows2matrix, which rejects 0x0 matrices).
    '''
    size = int(round(len(datastring) ** 0.5))
    if size * size != len(datastring):
        raise ValueError('datastring length must be a perfect square', datastring)
    # consecutive runs of ``size`` digits form the rows; max() keeps the step
    # nonzero for the empty string
    rows = [[int(char) for char in datastring[start:start + size]]
            for start in range(0, len(datastring), max(size, 1))]
    return db.rows2matrix(datastring, rows)
def _s(matrix):
    '''
    Flatten a matrix into one string of its values in row-major order
    (for single-digit values this is the inverse of _m).
    '''
    pieces = []
    for row in matrix.as_rowlist():
        pieces.extend(str(val) for val in row)
    return ''.join(pieces)
def test():
    '''smoke tests for MatrixDB / Matrix construction and arithmetic'''
    db = MatrixDB()
    id10 = db.make_identity('id10', 10)
    zero10 = db.make_zero('zero10', 10, 10)
    # check the contents actually produced by the factories
    assert id10.as_rowlist() == [[int(i == j) for j in range(10)] for i in range(10)]
    assert zero10.as_rowlist() == [[0] * 10] * 10
    assert (id10 + id10).matrix_cost() == id10.matrix_cost()
    assert id10 * id10 == id10
    val = id10 + id10
    val += val
    assert val * id10 == val
    assert val + zero10 == val
    # addition is commutative
    assert _s(_m(db, '0100') + _m(db, '0111')) == _s(_m(db, '0111') + _m(db, '0100'))
    assert _m(db, '0100') + _m(db, '0111') == _m(db, '0211')
class DiGraphTC(object):
    '''
    a directed acyclic graph that maintains its transitive closure

    ``adjacency[p, c]`` is 1 when the edge p -> c exists.
    ``transitive[a, d]`` counts the paths from a to d.  It starts as the
    identity, which supplies the length-0 path from every node to itself.
    Keeping counts (not just 0/1 reachability) is what lets remove_edge
    subtract the paths through an edge.  The counts are only finite for acyclic
    graphs, so add_edge refuses edges that would close a cycle.
    '''

    def __init__(self, num_nodes):
        self.db = MatrixDB()
        self.adjacency = self.db.make_zero('A', num_nodes, num_nodes)
        self.transitive = self.db.make_identity('T', num_nodes)
        self.num_nodes = num_nodes

    def add_edge(self, parent, child):
        '''
        add the edge parent -> child; adding an existing edge is a no-op

        Raises ValueError if child is already an ancestor of parent (including
        parent == child), since the edge would create a cycle.
        '''
        if self.adjacency[parent, child]:
            return  # counting the edge's paths a second time would corrupt T
        if self.ancestor_of(child, parent):
            raise ValueError('edge would create a cycle', parent, child)
        self.adjacency[parent, child] = 1  # add the edge
        self._increment2(parent, child, 1)

    def remove_edge(self, parent, child):
        '''remove the edge parent -> child; raises ValueError if it does not exist'''
        # explicit check: an assert would be stripped under -O and the
        # decrement below would then corrupt the path counts
        if self.adjacency[parent, child] != 1:
            raise ValueError('no such edge', parent, child)
        self.adjacency[parent, child] = 0
        self._increment2(parent, child, -1)

    def ancestor_of(self, ancestor, descendent):
        '''True if there is a path ancestor -> descendent (every node reaches itself)'''
        return bool(self.transitive[ancestor, descendent])

    def _increment(self, parent, child, amount):
        '''
        older variant of _increment2: T = T + T * S * T with S holding
        ``amount`` at (parent, child), built from whole-matrix products;
        not called by the other methods of this class
        '''
        assert amount in (1, -1)
        increment = self.db.make_zero('S', self.num_nodes, self.num_nodes)
        increment[parent, child] = amount
        self.transitive = self.transitive + (self.transitive * increment * self.transitive)

    def _increment2(self, parent, child, amount):
        '''
        add amount * (column ``parent`` of T) x (row ``child`` of T) to T,
        i.e. add/remove every path that runs through the edge parent -> child
        '''
        assert amount in (1, -1)
        Tdelta = self.db.make_zero('Tdelta', self.num_nodes, self.num_nodes)
        self.db.sql(  # simultaneous T * S * T operation
            '''
            INSERT INTO element (matrix_id, row, col, val)
            SELECT :Tdelta, by_col.row, by_row.col, :amount * sum(by_col.val * by_row.val)
            FROM element as by_col, element as by_row
            WHERE by_col.col = :parent AND by_row.row = :child AND by_col.matrix_id = :T AND by_row.matrix_id = :T
            GROUP BY by_col.row, by_row.col
            ''',
            dict(Tdelta=Tdelta.id, T=self.transitive.id, amount=amount, parent=parent, child=child))
        self.transitive += Tdelta

    def cost(self):
        '''total number of element rows stored for both matrices'''
        return self.adjacency.matrix_cost() + self.transitive.matrix_cost()
def test_adjacency():
    '''check ancestor queries on a 5-node chain as edges are added and removed'''
    team_mgr = DiGraphTC(5)
    team_mgr.add_edge(0, 1)
    team_mgr.add_edge(1, 2)
    assert team_mgr.ancestor_of(0, 1)
    assert team_mgr.ancestor_of(0, 2)
    assert not team_mgr.ancestor_of(2, 0)  # edges are directed
    team_mgr.add_edge(2, 3)
    team_mgr.add_edge(3, 4)
    assert team_mgr.ancestor_of(0, 4)
    team_mgr.remove_edge(2, 3)
    assert not team_mgr.ancestor_of(0, 4)
    # removing 2 -> 3 must leave both halves of the chain intact
    assert team_mgr.ancestor_of(0, 2)
    assert team_mgr.ancestor_of(3, 4)
    assert not team_mgr.ancestor_of(2, 3)
def test_adj_perf():
    '''
    rough benchmark: add edges in chains of ten to a 100k-node graph, printing
    storage cost (element rows) and elapsed time at round-numbered steps;
    gives up after 20 seconds
    '''
    team_mgr = DiGraphTC(100000)
    # single-string prints give identical output on Python 2 and 3
    print("100k vertices, 0 edges: cost {}".format(team_mgr.cost()))
    start = time.time()
    # // keeps the bound an integer on Python 3 as well
    for j in range(team_mgr.num_nodes // 10):
        if time.time() - start > 20:
            print("stopping after 20 seconds...")
            break
        if str(j)[1:].replace('0', '') == '':
            # j is 0-9 or a round number like 10, 20, 100, 2000, 40000
            print("100k vertices, {} edges: cost nrows {} time {:0.2F}".format(
                j * 10, team_mgr.cost(), time.time() - start))
        for i in range(10):
            parent = j * 10 + i
            child = j * 10 + i + 1
            if parent < team_mgr.num_nodes and child < team_mgr.num_nodes:
                team_mgr.add_edge(parent, child)
if __name__ == '__main__':
    # run the self-tests in order: arithmetic, graph closure, then the benchmark
    for run in (test, test_adjacency, test_adj_perf):
        run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment