Last active
December 13, 2018 23:44
-
-
Save kurtbrose/fad1baf1b4cff07f9295fb2377d76663 to your computer and use it in GitHub Desktop.
Matrices in a SQLite database (mostly a prototype for linear algebra in larger SQL databases).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Implementation of some linear algebra in SQL (SQLite).

Matrices are stored sparsely in two tables (see SCHEMA): ``matrix`` holds one
row per matrix (name and dimensions) and ``element`` holds one row per stored
entry; missing entries are implicit zeros.  MatrixDB owns the connection and
implements the operations, Matrix is a thin handle exposing Python operators,
and DiGraphTC uses both to maintain the transitive closure of a graph.
'''
import sqlite3
import time
import weakref

# sentinel meaning "no default supplied" (None is a legitimate default value)
_MISSING = object()
| class MatrixDB(object): | |
| ''' | |
| Serves two purposes -- | |
| 1 - stores underlying data | |
| 2 - serves as a factory for Matrix instances | |
| matrices are stored as sparse -- missing elements are implicit 0 | |
| ''' | |
| def __init__(self): | |
| self.db = sqlite3.connect(':memory:') | |
| for statement in SCHEMA.split(';'): | |
| statement = statement.strip() | |
| if statement: | |
| self.db.execute(statement) | |
| def execute(self, sql): | |
| '''pass-through''' | |
| return self.db.execute(sql) | |
| def sql(self, sql, args=None): | |
| '''run sql against underlying DB, fetch and returns results''' | |
| if args is None: | |
| return self.db.execute(sql).fetchall() | |
| return self.db.execute(sql, args).fetchall() | |
| def sql_val(self, sql, args=None, default=_MISSING): | |
| '''run SELECT sql that returns single value''' | |
| result = self.sql(sql, args) | |
| if not result: | |
| if default is not _MISSING: | |
| return default | |
| raise ValueError("sql returned no rows", sql, args) | |
| return result[0][0] | |
| def sql_col(self, sql, args=None): | |
| '''run SELECT and return [a, b, c] instead of [(a,), (b,), (c,)]''' | |
| return list(sum(self.sql(sql, args), ())) | |
| def sql_row(self, sql, args=None): | |
| '''run SELECT and return [a, b, c] instead of [(a, b, c)]''' | |
| return self.sql(sql, args)[0] | |
| def make_zero(self, name, nrows, ncols): | |
| '''create a zero matrix''' | |
| return Matrix(self, self._make_zero(name, nrows, ncols)) | |
| def _make_zero(self, name, nrows, ncols): | |
| self.sql('INSERT INTO matrix (name, nrows, ncols) VALUES (?, ?, ?)', (name, nrows, ncols)) | |
| return self.sql_val('SELECT max(ROWID) FROM matrix') | |
| def make_identity(self, name, size): | |
| '''create identity matrix''' | |
| matrix_id = self._make_zero(name, size, size) | |
| for i in range(size): | |
| self.sql( | |
| 'INSERT INTO element (matrix_id, row, col, val) VALUES (?, ?, ?, 1)', | |
| (matrix_id, i, i)) | |
| return Matrix(self, matrix_id) | |
| def rows2matrix(self, name, rowlist): | |
| if not rowlist: | |
| raise ValueError('cannot construct 0x0 matrix') | |
| if len(set([len(row) for row in rowlist])) != 1: | |
| raise ValueError('all rows must have the same length') | |
| nrows = len(rowlist) | |
| ncols = len(rowlist[0]) | |
| matrix_id = self._make_zero(name, nrows, ncols) | |
| for row in range(nrows): | |
| for col in range(ncols): | |
| self.set_val(matrix_id, rowlist[row][col], row, col) | |
| return Matrix(self, matrix_id) | |
| def _name2id(self, name): | |
| if type(name) in (int, long): | |
| return name | |
| return self.sql_val('SELECT ROWID FROM matrix WHERE name = ?', (name,)) | |
| def matrix_dims(self, name): | |
| '''return (nrows, ncols)''' | |
| matrix_id = self._name2id(name) | |
| return self.sql_row('SELECT nrows, ncols FROM matrix WHERE ROWID = ?', (matrix_id,)) | |
| def matrix_cost(self, name): | |
| '''return the number of rows in the database used to represent the matrix''' | |
| matrix_id = self._name2id(name) | |
| return self.sql_val('SELECT count(*) FROM element WHERE matrix_id = ?', (matrix_id,)) | |
| def _clean(self, matrix_id): | |
| self.sql('DELETE FROM element WHERE matrix_id = ? AND val = 0', (matrix_id,)) | |
| def mult(self, lhs_name, rhs_name, result_name=None): | |
| '''multiply two matrices together''' | |
| if result_name is None: | |
| result_name = '{}*{}'.format(lhs_name, rhs_name) | |
| lhs_id = self._name2id(lhs_name) | |
| rhs_id = self._name2id(rhs_name) | |
| rhs_rows, rhs_cols = self.matrix_dims(rhs_name) | |
| lhs_rows, lhs_cols = self.matrix_dims(lhs_name) | |
| if lhs_cols != rhs_rows: | |
| raise ValueError('cannot multiple matrix unless LHS cols = RHS rows') | |
| result_id = self._make_zero(result_name, lhs_rows, rhs_cols) | |
| self.sql( | |
| ''' | |
| INSERT INTO element (matrix_id, val, row, col) | |
| SELECT ?, lhs.val * rhs.val, lhs.row, rhs.col | |
| FROM element as lhs JOIN element as rhs ON (lhs.col = rhs.row) | |
| WHERE lhs.matrix_id = ? AND rhs.matrix_id = ? | |
| GROUP BY lhs.row, rhs.col''', | |
| (result_id, lhs_id, rhs_id)) | |
| self._clean(result_id) | |
| return result_name | |
| def add(self, lhs_name, rhs_name, result_name=None): | |
| '''add two matrices''' | |
| if result_name is None: | |
| result_name = "({}+{})".format(lhs_name, rhs_name) | |
| lhs_id = self._name2id(lhs_name) | |
| rhs_id = self._name2id(rhs_name) | |
| if self.matrix_dims(lhs_name) != self.matrix_dims(rhs_name): | |
| raise ValueError('cannot add matrices of different sizes') | |
| result_id = self._make_zero(result_name, *self.matrix_dims(lhs_name)) | |
| self.sql( # add values that are in both matrices | |
| ''' | |
| INSERT INTO element (matrix_id, val, row, col) | |
| WITH lhs as (SELECT * FROM element WHERE matrix_id = ?), | |
| rhs as (SELECT * FROM element WHERE matrix_id = ?) | |
| SELECT ?, val, row, col | |
| FROM ( | |
| SELECT lhs.val + rhs.val as val, lhs.row as row, lhs.col as col | |
| FROM lhs JOIN rhs ON (lhs.row = rhs.row AND lhs.col = rhs.col) | |
| UNION ALL | |
| SELECT val, row, col FROM lhs WHERE NOT EXISTS ( | |
| SELECT 1 FROM rhs WHERE (lhs.row = rhs.row AND lhs.col = rhs.col)) | |
| UNION ALL | |
| SELECT val, row, col FROM rhs WHERE NOT EXISTS( | |
| SELECT 1 FROM lhs WHERE (lhs.row = rhs.row AND lhs.col = rhs.col)) | |
| )''', | |
| (lhs_id, rhs_id, result_id)) | |
| self._clean(result_id) | |
| return result_name | |
| def accumulate(self, base, incr): | |
| '''inline add -- base += incr''' | |
| base = self._name2id(base) | |
| incr = self._name2id(incr) | |
| self.sql( # add the values where indices line up | |
| ''' | |
| UPDATE element SET val = val + ( | |
| SELECT val FROM element as incr | |
| WHERE incr.matrix_id = :incr AND element.row = incr.row AND element.col = incr.col) | |
| WHERE element.matrix_id = :base AND EXISTS ( | |
| SELECT 1 FROM element as incr WHERE | |
| incr.matrix_id = :incr AND element.row = incr.row AND element.col = incr.col) | |
| ''', | |
| dict(base=base, incr=incr)) | |
| self.sql( # set any values which are not present (implicit 0's) | |
| ''' | |
| INSERT INTO element (matrix_id, val, row, col) | |
| SELECT :base, val, row, col FROM element as incr WHERE | |
| incr.matrix_id = :incr AND NOT EXISTS ( | |
| SELECT 1 FROM element as base | |
| WHERE base.matrix_id = :base AND base.row = incr.row AND base.col = incr.col)''', | |
| dict(base=base, incr=incr)) | |
| self._clean(base) # remove any rows that have become 0 | |
| def _check_key(self, matrix_id, row, col): | |
| nrows, ncols = self.matrix_dims(matrix_id) | |
| if nrows <= row or ncols <= col: | |
| raise KeyError('out of size of matrix') | |
| def set_val(self, name, val, row, col): | |
| '''set a value in a matrix''' | |
| matrix_id = self._name2id(name) | |
| self._check_key(matrix_id, row, col) | |
| if val == 0: | |
| self.sql( | |
| 'DELETE FROM element WHERE matrix_id = ? AND row = ? AND col = ?', | |
| (matrix_id, row, col)) | |
| else: | |
| self.sql( | |
| 'INSERT OR REPLACE INTO element (val, matrix_id, row, col) VALUES (?, ?, ?, ?)', | |
| (val, matrix_id, row, col)) | |
| def get_val(self, name, row, col): | |
| '''get one value from a matrix''' | |
| matrix_id = self._name2id(name) | |
| self._check_key(matrix_id, row, col) | |
| return (self.sql_col( | |
| 'SELECT val FROM element WHERE matrix_id = ? and row = ? and col = ?', | |
| (matrix_id, row, col)) + [0])[0] | |
| """ | |
| def is_eq(self, lhs_name, rhs_name): | |
| '''check if two matrices are the same''' | |
| if self.matrix_dims(lhs_name) != self.matrix_dims(rhs_name): | |
| return False | |
| lhs_id = self._name2id(lhs_name) | |
| rhs_id = self._name2id(rhs_name) | |
| return bool(self.sql_val(''' | |
| SELECT | |
| EXISTS(SELECT 1 | |
| FROM element as lhs LEFT OUTER JOIN element as rhs ON | |
| (lhs.row = rhs.row AND lhs.col = rhs.col) | |
| WHERE lhs.val IS NULL AND lhs.matrix_id = ? AND rhs.matrix_id = ?) OR | |
| EXISTS(SELECT 1 | |
| FROM element as rhs LEFT OUTER JOIN element as lhs ON | |
| (lhs.row = rhs.row AND lhs.col = rhs.col) | |
| WHERE rhs.val IS NULL AND lhs.matrix_id = ? AND rhs.matrix_id = ?)''', | |
| (lhs_id, rhs_id) * 2)) | |
| """ | |
| def dump(self, name): | |
| '''dump matrix to list-of-rows''' | |
| matrix_id = self._name2id(name) | |
| nrows, ncols = self.matrix_dims(name) | |
| rows = [] | |
| for i in range(nrows): | |
| row = [] | |
| rows.append(row) | |
| for j in range(ncols): | |
| row.append((self.sql_col( | |
| 'SELECT val FROM element WHERE matrix_id = ? AND row = ? AND col = ?', | |
| (matrix_id, i, j)) + [0])[0]) | |
| return rows | |
| def delete(self, name): | |
| matrix_id = self._name2id(name) | |
| self.sql('DELETE FROM matrix WHERE ROWID = ?', (name,)) | |
| self.sql('DELETE FROM element WHERE matrix_id = ?', (matrix_id,)) | |
| def zero_out_matrix(self, name): | |
| matrix_id = self._name2id(name) | |
| self.sql('DELETE FROM element WHERE matrix_id = ?', (matrix_id,)) | |
class _Cleaner(object):
    '''
    Weakref callback that deletes one matrix from its MatrixDB.

    It stores only the database and the matrix id, never the Matrix itself,
    so it cannot keep the Matrix alive.
    '''

    def __init__(self, db, id):
        self.db = db
        self.id = id

    def __call__(self, ref):
        # `ref` is the dead weak reference; only the remembered id matters
        database, matrix_id = self.db, self.id
        database.delete(matrix_id)
class Matrix(object):
    '''
    Handle to one matrix stored in a MatrixDB.

    Supports +, +=, *, == / != and m[row, col] get/set/delete.  When the
    handle is garbage-collected, the stored matrix is deleted from the database.
    '''

    # __eq__ compares contents, so identity hashing would be inconsistent;
    # Python 3 implies this, Python 2 needs it spelled out
    __hash__ = None

    def __init__(self, db, name):
        self.db, self.id = db, db._name2id(name)
        # the callback deletes the stored matrix once this handle is collected;
        # the weakref object itself must stay referenced, hence the attribute
        self._cleaner = weakref.ref(self, _Cleaner(db, self.id))

    def _same_db(self, other):
        '''True if other is a Matrix of this MatrixDB, False if not a Matrix'''
        if not isinstance(other, Matrix):
            return False
        if other.db is not self.db:
            raise ValueError('cannot combine matrices from different MatrixDB instances')
        return True

    def __add__(self, other):
        if not self._same_db(other):
            return NotImplemented
        return Matrix(self.db, self.db.add(self.id, other.id))

    def __iadd__(self, other):
        if not self._same_db(other):
            return NotImplemented
        self.db.accumulate(self.id, other.id)
        return self

    def __mul__(self, other):
        if not self._same_db(other):
            return NotImplemented
        return Matrix(self.db, self.db.mult(self.id, other.id))

    def matrix_cost(self):
        '''number of stored (non-implicit) elements'''
        return self.db.matrix_cost(self.id)

    def __setitem__(self, key, value):
        row, col = key
        self.db.set_val(self.id, value, row, col)

    def __getitem__(self, key):
        row, col = key
        return self.db.get_val(self.id, row, col)

    def __delitem__(self, key):
        row, col = key
        self.db.set_val(self.id, 0, row, col)

    def as_rowlist(self):
        '''dense list-of-rows copy of the matrix'''
        return self.db.dump(self.id)

    @property
    def name(self):
        return self.db.sql_val('SELECT name FROM matrix WHERE ROWID = ?', (self.id,))

    def __eq__(self, other):
        if not isinstance(other, Matrix):
            return NotImplemented
        # dense comparison, so implicit and explicitly stored zeros agree
        return self.db.dump(self.id) == other.db.dump(other.id)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __repr__(self):
        return '{}(id={!r})'.format(type(self).__name__, self.id)
# DDL executed by MatrixDB.__init__.  Keep ';' out of statement bodies so the
# script can also be split on ';' statement by statement.
#   coords        -- one element per (matrix, row, col); also serves row lookups
#   element_col   -- serves column lookups within one matrix (e.g. T[:, parent]),
#                    which coords cannot narrow past matrix_id
#   matrix_name   -- serves name -> ROWID resolution
SCHEMA = '''
CREATE TABLE matrix (
    name TEXT NOT NULL,
    nrows INTEGER NOT NULL,
    ncols INTEGER NOT NULL
);
CREATE TABLE element (
    matrix_id INTEGER NOT NULL,
    row INTEGER NOT NULL,
    col INTEGER NOT NULL,
    val INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX element_row_col ON element(row, col);
CREATE UNIQUE INDEX coords ON element(matrix_id, row, col);
CREATE INDEX element_col ON element(matrix_id, col, row);
CREATE INDEX matrix_name ON matrix(name);
'''
def _m(db, datastring):
    '''
    build a square matrix from a row-major string of single digits

    e.g. '0100' -> [[0, 1], [0, 0]].  The string itself is used as the matrix
    name.  Raises ValueError unless the length is a non-zero perfect square.
    '''
    size = int(round(len(datastring) ** 0.5))
    if size == 0 or size * size != len(datastring):
        raise ValueError('datastring length must be a non-zero perfect square', datastring)
    # split into consecutive chunks of `size` characters, one chunk per row
    rows = [[int(char) for char in datastring[start:start + size]]
            for start in range(0, len(datastring), size)]
    return db.rows2matrix(datastring, rows)
def _s(matrix):
    '''flatten a matrix into one string: every value str()-ed, row by row'''
    pieces = []
    for row in matrix.as_rowlist():
        pieces.extend(str(val) for val in row)
    return ''.join(pieces)
def test():
    '''smoke tests for MatrixDB and the Matrix operators'''
    db = MatrixDB()
    id10 = db.make_identity('id10', 10)
    zero10 = db.make_zero('zero10', 10, 10)
    # identity: 1 on the diagonal, 0 elsewhere; zero matrix stores nothing
    assert id10.as_rowlist() == [[int(i == j) for j in range(10)] for i in range(10)]
    assert zero10.matrix_cost() == 0
    assert (id10 + id10).matrix_cost() == id10.matrix_cost()
    assert id10 * id10 == id10
    val = id10 + id10
    val += val
    assert val * id10 == val
    assert val + zero10 == val
    # addition is commutative
    assert _s(_m(db, '0100') + _m(db, '0111')) == _s(_m(db, '0111') + _m(db, '0100'))
    assert _m(db, '0100') + _m(db, '0111') == _m(db, '0211')
class DiGraphTC(object):
    '''
    a directed acyclic graph with an incrementally maintained transitive closure

    ``adjacency[p, c]`` is 1 when the edge p -> c exists.  ``transitive``
    starts as the identity and, on every edge change, receives
    +/- T[:, parent] * T[child, :], so ``transitive[a, d]`` counts the paths
    from a to d (the empty path included).  That rank-one update is only
    exact while the graph stays acyclic, so edges closing a cycle are rejected.
    '''

    def __init__(self, num_nodes):
        self.db = MatrixDB()
        self.adjacency = self.db.make_zero('A', num_nodes, num_nodes)
        self.transitive = self.db.make_identity('T', num_nodes)
        self.num_nodes = num_nodes

    def add_edge(self, parent, child):
        '''
        add edge parent -> child; adding an existing edge is a no-op

        Raises ValueError if the edge would create a cycle (including a self-loop).
        '''
        if self.adjacency[parent, child]:
            # re-applying the update would count every path through this edge
            # twice, and a single remove_edge() could never undo it
            return
        if self.ancestor_of(child, parent):
            # also catches parent == child, since the diagonal of T is 1
            raise ValueError('edge would create a cycle', parent, child)
        self.adjacency[parent, child] = 1  # add the edge
        self._increment2(parent, child, 1)

    def remove_edge(self, parent, child):
        '''remove the existing edge parent -> child; ValueError if it is absent'''
        if self.adjacency[parent, child] != 1:
            raise ValueError('no such edge', parent, child)
        self.adjacency[parent, child] = 0
        self._increment2(parent, child, -1)

    def ancestor_of(self, ancestor, descendent):
        '''True if there is a path from ancestor to descendent (or they are equal)'''
        return bool(self.transitive[ancestor, descendent])

    def _increment(self, parent, child, amount):
        '''reference version of _increment2 using Matrix arithmetic: T += T*S*T'''
        assert amount in (1, -1)
        increment = self.db.make_zero('S', self.num_nodes, self.num_nodes)
        increment[parent, child] = amount
        self.transitive = self.transitive + (self.transitive * increment * self.transitive)

    def _increment2(self, parent, child, amount):
        '''
        T += amount * T[:, parent] * T[child, :] in one SQL statement
        (the same update as _increment without materializing S or T*S*T)
        '''
        assert amount in (1, -1)
        Tdelta = self.db.make_zero('Tdelta', self.num_nodes, self.num_nodes)
        self.db.sql(  # simultaneous T * S * T operation
            '''
            INSERT INTO element (matrix_id, row, col, val)
            SELECT :Tdelta, by_col.row, by_row.col, :amount * sum(by_col.val * by_row.val)
            FROM element as by_col, element as by_row
            WHERE by_col.col = :parent AND by_row.row = :child AND by_col.matrix_id = :T AND by_row.matrix_id = :T
            GROUP BY by_col.row, by_row.col
            ''',
            dict(Tdelta=Tdelta.id, T=self.transitive.id, amount=amount, parent=parent, child=child))
        self.transitive += Tdelta

    def cost(self):
        '''total stored rows used by the adjacency and closure matrices'''
        return self.adjacency.matrix_cost() + self.transitive.matrix_cost()
def test_adjacency():
    '''exercise DiGraphTC on the 5-node chain 0 -> 1 -> 2 -> 3 -> 4'''
    team_mgr = DiGraphTC(5)
    team_mgr.add_edge(0, 1)
    team_mgr.add_edge(1, 2)
    assert team_mgr.ancestor_of(0, 2)
    team_mgr.add_edge(2, 3)
    team_mgr.add_edge(3, 4)
    assert team_mgr.ancestor_of(0, 4)
    team_mgr.remove_edge(2, 3)
    assert not team_mgr.ancestor_of(0, 4)
    # both halves of the broken chain keep an exact closure
    assert team_mgr.ancestor_of(0, 2) and team_mgr.ancestor_of(3, 4)
    assert not team_mgr.ancestor_of(2, 3)
    # re-adding the edge restores reachability
    team_mgr.add_edge(2, 3)
    assert team_mgr.ancestor_of(0, 4)
def test_adj_perf():
    '''
    rough timing run: add chain edges to a 100k-vertex graph for up to 20 s,
    printing the storage cost and elapsed time at round edge counts
    '''
    team_mgr = DiGraphTC(100000)
    # single-argument print() gives the same output on Python 2 and 3
    print("100k vertices, 0 edges: cost {}".format(team_mgr.cost()))
    start = time.time()
    for j in range(team_mgr.num_nodes // 10):  # // keeps range() integral on Python 3
        if time.time() - start > 20:
            print("stopping after 20 seconds...")
            break
        if str(j)[1:].replace('0', '') == '':
            # j is a round number like 10, 20, 100, 2000, 40000
            print("100k vertices, {} edges: cost nrows {} time {:0.2F}".format(
                j * 10, team_mgr.cost(), time.time() - start))
        for i in range(10):
            parent = j * 10 + i
            child = parent + 1  # child > parent, so checking child bounds both
            if child < team_mgr.num_nodes:
                team_mgr.add_edge(parent, child)
if __name__ == '__main__':
    # quick correctness checks first, then the slow timing run
    for _check in (test, test_adjacency, test_adj_perf):
        _check()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment