Created
April 5, 2014 13:38
-
-
Save Midnighter/9992103 to your computer and use it in GitHub Desktop.
Testing methods to set the diagonal of a scipy sparse matrix to zero.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:94ef4128925ef55646ae9c7f9007b8324f938be76b80e78a77664868d845d940" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "heading", | |
"level": 1, | |
"metadata": {}, | |
"source": [ | |
"Set the Diagonal of a Sparse Matrix to Zero" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"import numpy as np\n", | |
"import scipy\n", | |
"import scipy.sparse as sp" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"one_dim = 10 ** 3" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Construct a random 0/1 matrix in which roughly 20% of the entries are nonzero." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"rng = np.random.RandomState(2014)\n", | |
"mat = sp.csr_matrix((rng.random_sample((one_dim, one_dim)) > 0.8).astype(float))" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 3 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"mat.nnz / float(one_dim ** 2)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 4, | |
"text": [ | |
"0.199934" | |
] | |
} | |
], | |
"prompt_number": 4 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"np.count_nonzero(mat.diagonal())" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 5, | |
"text": [ | |
"217" | |
] | |
} | |
], | |
"prompt_number": 5 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"First naive attempt: `setdiag` also inserts explicit zeros for diagonal elements that were not previously stored, which changes the sparsity structure of the CSR matrix and is slow." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%%timeit\n", | |
"cpy = mat.copy()\n", | |
"cpy.setdiag(np.zeros(one_dim))\n", | |
"assert not cpy.diagonal().any()" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"1 loops, best of 3: 816 ms per loop\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stderr", | |
"text": [ | |
"/home/moritz/.virtualenvs/test-diag/lib/python2.7/site-packages/scipy/sparse/compressed.py:728: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n", | |
" SparseEfficiencyWarning)\n" | |
] | |
} | |
], | |
"prompt_number": 6 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Convert to LIL format first, where changing the sparsity structure is much cheaper (as the warning above suggests), then convert back to CSR." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%%timeit\n", | |
"lil = mat.tolil()\n", | |
"lil.setdiag(np.zeros(one_dim))\n", | |
"cpy = lil.tocsr()\n", | |
"assert not cpy.diagonal().any()" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"10 loops, best of 3: 111 ms per loop\n" | |
] | |
} | |
], | |
"prompt_number": 7 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Implement a function that only overwrites diagonal elements that are already stored, so the sparsity structure stays unchanged." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"def csr_setdiag_val(csr, value=0):\n", | |
"    \"\"\"Set all stored diagonal elements of a CSR matrix to `value`, in place.\n", | |
"\n", | |
"    Only diagonal entries already present in the sparsity pattern are\n", | |
"    overwritten, so no new entries are inserted (the matrix is only\n", | |
"    canonicalized: indices sorted and duplicate entries summed).\n", | |
"    Mostly useful to set the diagonal to 0.\n", | |
"\n", | |
"    Raises ValueError if `csr` is not in CSR format.\n", | |
"    \"\"\"\n", | |
"    if csr.format != \"csr\":\n", | |
"        raise ValueError('Matrix given must be of CSR format.')\n", | |
"    # Merge duplicate (row, column) entries first; otherwise only one of\n", | |
"    # several stored copies of a diagonal element would be overwritten.\n", | |
"    csr.sum_duplicates()\n", | |
"    nnz = csr.indptr[-1]\n", | |
"    # Row index of every stored element, aligned with indices/data.\n", | |
"    rows = np.repeat(np.arange(csr.shape[0]), np.diff(csr.indptr))\n", | |
"    on_diagonal = rows == csr.indices[:nnz]\n", | |
"    # data[:nnz] is a view, so the masked assignment modifies csr in place.\n", | |
"    csr.data[:nnz][on_diagonal] = value" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 8 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%%timeit\n", | |
"cpy = mat.copy()\n", | |
"csr_setdiag_val(cpy)\n", | |
"assert not cpy.diagonal().any()" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"100 loops, best of 3: 9.87 ms per loop\n" | |
] | |
} | |
], | |
"prompt_number": 9 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Create a sparse matrix from the diagonal and subtract it from the original." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"%%timeit\n", | |
"cpy = mat - sp.dia_matrix((mat.diagonal()[np.newaxis, :], [0]), shape=mat.shape)\n", | |
"assert not cpy.diagonal().any()" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"1000 loops, best of 3: 1.99 ms per loop\n" | |
] | |
} | |
], | |
"prompt_number": 10 | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment