Last active
January 11, 2021 07:44
-
-
Save cheind/35dfb3e67263dfdf7da80366c18db320 to your computer and use it in GitHub Desktop.
Transactional-like (undoable) matrix operations (row delete, column permute) implemented based on permutation matrices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from itertools import count | |
def perm_matrix(perm_indices):
    '''Returns the permutation matrix corresponding to given permutation indices.

    Here `perm_indices` defines the permutation order in the following sense:
    value `j` at index `i` will move row/column `j` of the original matrix to
    row/column `i` in the permuted matrix P*M/M*P^T.

    Params
    ------
    perm_indices: N
        permutation order

    Returns
    -------
    pm: NxN
        int32 matrix whose row `i` is the standard basis vector e_j with
        j = perm_indices[i].
    '''
    N = len(perm_indices)
    # Converting to an index array first keeps tuples from being read as a
    # multi-dimensional index and lets an empty input yield a 0x0 matrix.
    rows = np.asarray(perm_indices, dtype=np.intp)
    # Row i of the result is row perm_indices[i] of the identity, so a single
    # fancy-indexing step builds the whole matrix without a Python loop.
    return np.eye(N, dtype=np.int32)[rows]
def binary_perm_matrix(i, j, N):
    '''Builds the NxN permutation matrix that exchanges row/column i and j.

    Starts from the identity ordering 0..N-1, swaps the entries at positions
    `i` and `j`, and converts that ordering with `perm_matrix`.
    '''
    order = np.arange(N)
    order[i], order[j] = j, i
    return perm_matrix(order)
def basis_vec(i, n, dtype=None):
    '''Builds e_i, the i-th standard basis vector of R^n.

    Params
    ------
    i: index of the single entry set to one
    n: length of the vector
    dtype: optional dtype forwarded to `np.zeros`

    Returns
    -------
    vector of length `n` that is zero everywhere except for a one at `i`.
    '''
    vec = np.zeros(n, dtype=dtype)
    vec[i] = 1
    return vec
class MatrixState:
    '''A matrix seen through accumulated row/column permutations and deletions.

    The wrapped matrix `m` is only read, never written. The current view is
    obtained by multiplying it with the row permutation `rp` from the left
    and the column permutation `cp` from the right, and then cutting off the
    trailing `dr` rows and `dc` columns.
    '''

    def __init__(self, m):
        n_rows, n_cols = m.shape
        self.R = n_rows
        self.C = n_cols
        self.m = m
        # Counters of deleted rows / columns.
        self.dr = 0
        self.dc = 0
        # Accumulated row / column permutations, starting at the identity.
        self.rp = np.eye(n_rows, dtype=m.dtype)
        self.cp = np.eye(n_cols, dtype=m.dtype)
        # Starts empty; no method of this class reads or writes it.
        self.history = []

    @property
    def matrix(self):
        '''Returns matrix as represented by the current state'''
        live_rows = self.R - self.dr
        live_cols = self.C - self.dc
        permuted = self.rp @ self.m @ self.cp
        return permuted[:live_rows, :live_cols]

    @property
    def indices(self):
        '''Returns original row and column indices of the current matrix state.'''
        live_rows = self.R - self.dr
        live_cols = self.C - self.dc
        # The column position of each non-zero entry of a row names the
        # original row (resp. column, via the transpose) shown at that place.
        _, row_source = np.nonzero(self.rp)
        _, col_source = np.nonzero(self.cp.T)
        return row_source[:live_rows], col_source[:live_cols]

    def transaction(self):
        '''Returns a new `MatrixTransaction` operating on this state.'''
        return MatrixTransaction(self)
class UndoableMatrixOp:
    '''Interface of an operation that can be applied to and reverted on a state.

    Both methods raise `NotImplementedError` here; subclasses override them.
    '''

    def apply(self, state):
        '''Performs the operation on `state`.'''
        raise NotImplementedError()

    def undo(self, state):
        '''Reverts the operation on `state`.'''
        raise NotImplementedError()
class SwapRowsOp(UndoableMatrixOp):
    '''Undoable exchange of two rows of the current matrix state.'''

    def __init__(self, i, j):
        self.ids = (i, j)

    def apply(self, state):
        '''Left-multiplies the row permutation of `state` with the swap matrix.'''
        # The swap matrix is kept on the op so that `undo` can reuse it.
        self.p = binary_perm_matrix(*self.ids, state.R)
        state.rp = self.p @ state.rp

    def undo(self, state):
        '''Reverts `apply` by left-multiplying with the transposed swap matrix.'''
        state.rp = self.p.T @ state.rp
class SwapColsOp(UndoableMatrixOp):
    '''Undoable exchange of two columns of the current matrix state.'''

    def __init__(self, i, j):
        self.ids = (i, j)

    def apply(self, state):
        '''Right-multiplies the column permutation of `state` with the swap matrix.'''
        # The swap matrix is kept on the op so that `undo` can reuse it.
        self.p = binary_perm_matrix(*self.ids, state.C)
        state.cp = state.cp @ self.p

    def undo(self, state):
        '''Reverts `apply` by right-multiplying with the transposed swap matrix.'''
        state.cp = state.cp @ self.p.T
class DeleteOp(UndoableMatrixOp):
    '''Undoable deletion of rows or columns of the current matrix state.

    Nothing is removed from the underlying data: the selected rows/columns
    are permuted to the end of the not yet deleted range and the state's
    deletion counter (`dr` or `dc`) is increased by the number of distinct
    ids.

    Params
    ------
    ids: int or collection of ints
        indices, relative to the current (visible) matrix, to delete
    rows: bool
        delete rows when True, columns otherwise
    '''

    def __init__(self, ids, rows=True):
        # NumPy integer scalars are not instances of `int`; wrap them too so
        # that a single index can be given in either form.
        if isinstance(ids, (int, np.integer)):
            ids = [ids]
        self.ids = ids
        self.rows = rows

    def apply(self, state):
        '''Applies the deletion to `state`.

        Raises
        ------
        IndexError
            if an id is outside the currently visible range; `state` is left
            untouched in that case.
        '''
        # Duplicates must not be counted twice, otherwise more rows/columns
        # would be hidden than were actually moved to the end.
        unique_ids = set(self.ids)
        # Build (and thereby validate) the permutation before mutating state.
        self.p = DeleteOp.delete_perm_matrix(state, unique_ids, rows=self.rows)
        self.count = len(unique_ids)
        if self.rows:
            state.rp = self.p @ state.rp
            state.dr += self.count
        else:
            state.cp = state.cp @ self.p
            state.dc += self.count

    def undo(self, state):
        '''Reverts a previous `apply` on `state`.'''
        if self.rows:
            state.dr -= self.count
            state.rp = self.p.T @ state.rp
        else:
            state.dc -= self.count
            state.cp = state.cp @ self.p.T

    @staticmethod
    def delete_perm_matrix(state, ids, rows=True):
        '''Returns the permutation matrix that moves deleted rows/columns to the end of the array.

        Only the first `N - d` positions (N total, d already deleted) take
        part; positions of previously deleted rows/columns stay where they
        are. Kept indices retain their relative order at the front, deleted
        ones follow in descending order.

        Raises
        ------
        IndexError
            if an id does not lie in `0 .. N - d - 1`.
        '''
        N = state.R if rows else state.C
        d = state.dr if rows else state.dc
        upper = N - d  # ignore already deleted ones
        live = range(upper)
        # Set membership keeps the partition below linear in `upper`.
        deleted = set(ids)
        invalid = [i for i in deleted if i not in live]
        if invalid:
            raise IndexError(
                'cannot delete indices {} from a range of size {}'.format(invalid, upper)
            )
        kept = [i for i in live if i not in deleted]
        removed = [i for i in reversed(live) if i in deleted]
        pids = np.arange(N).astype(dtype=np.int32)  # each entry holds source index
        pids[:upper] = kept + removed
        p = perm_matrix(pids)
        return p if rows else p.T
class MatrixTransaction:
    '''Context manager that groups undoable operations on a matrix state.

    Operations applied through the transaction are recorded. When the `with`
    block is left without a prior call to `commit`, all recorded operations
    are undone in reverse order. Exceptions raised inside the block are not
    suppressed.

    The operation methods return the transaction itself, so calls can be
    chained.
    '''

    def __init__(self, matrix_state):
        self.matrix_state = matrix_state
        self.committed = None
        # Start with an empty log so that an operation applied before
        # entering the `with` block is still recorded and can be undone.
        self.ops = []

    def __enter__(self):
        self.committed = False
        self.ops = []
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.committed:
            self._undo_all()

    def commit(self):
        '''Marks the transaction as committed; leaving the block keeps all changes.'''
        self.committed = True

    def swap_rows(self, i, j):
        '''Exchanges rows `i` and `j` of the current matrix.'''
        return self._apply(SwapRowsOp(i, j))

    def swap_cols(self, i, j):
        '''Exchanges columns `i` and `j` of the current matrix.'''
        return self._apply(SwapColsOp(i, j))

    def delete_rows(self, ids):
        '''Deletes the given row index or indices from the current matrix.'''
        return self._apply(DeleteOp(ids, rows=True))

    def delete_cols(self, ids):
        '''Deletes the given column index or indices from the current matrix.'''
        return self._apply(DeleteOp(ids, rows=False))

    def _undo(self):
        '''Reverts the most recently recorded operation.'''
        op = self.ops.pop()
        op.undo(self.matrix_state)
        return self

    def _undo_all(self):
        '''Reverts every recorded operation, newest first.'''
        while self.ops:
            self._undo()

    def _apply(self, op):
        '''Applies `op` to the state, records it and returns the transaction.'''
        op.apply(self.matrix_state)
        self.ops.append(op)
        return self
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numpy.testing import assert_allclose | |
import matrix_transactions as u | |
def test_matrix_ops():
    '''Exercises swap/delete operations with and without commit, including nesting.'''
    swapped_and_cut = [[5, 4, 3], [2, 1, 0]]

    # Uncommitted transaction: changes are visible inside the block and
    # rolled back when it is left.
    m = np.arange(9, dtype=np.float32).reshape(3, 3)
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.swap_cols(0, 2)
        t.swap_rows(0, 1)
        t.delete_rows(2)
        assert_allclose(ms.matrix, swapped_and_cut)
    assert_allclose(ms.matrix, m)

    # Committed transaction: changes survive leaving the block.
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.swap_cols(0, 2)
        t.swap_rows(0, 1)
        t.delete_rows(2)
        assert_allclose(ms.matrix, swapped_and_cut)
        t.commit()
    assert_allclose(ms.matrix, swapped_and_cut)

    # Nested uncommitted transactions: the inner delete is relative to the
    # matrix left by the outer one, and both are rolled back at the end.
    m = np.arange(10, dtype=np.float32).reshape(5, 2)
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.delete_rows([2, 3, 0])
        assert_allclose(ms.matrix, [[2, 3], [8, 9]])
        row_ids, col_ids = ms.indices
        assert_allclose(row_ids, [1, 4])
        assert_allclose(col_ids, [0, 1])
        with ms.transaction() as tt:
            tt.delete_rows(0)
            assert_allclose(ms.matrix, [[8, 9]])
            row_ids, col_ids = ms.indices
            assert_allclose(row_ids, [4])
            assert_allclose(col_ids, [0, 1])
    assert_allclose(ms.matrix, m)

    # Column and row deletion combined in one uncommitted transaction.
    m = np.arange(20, dtype=np.float32).reshape(4, 5)
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.delete_cols([1, 2, 3])
        t.delete_rows([0, 2])
        assert_allclose(ms.matrix, [[5, 9], [15, 19]])
        row_ids, col_ids = ms.indices
        assert_allclose(row_ids, [1, 3])
        assert_allclose(col_ids, [0, 4])
    assert_allclose(ms.matrix, m)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment