Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active October 14, 2025 10:25
Show Gist options
  • Save ricardoV94/4e57e311db860bf5264267a86553e17c to your computer and use it in GitHub Desktop.
Save ricardoV94/4e57e311db860bf5264267a86553e17c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1a5344cc-2a05-46a5-b31b-29ec3bd63b5a",
"metadata": {},
"outputs": [],
"source": [
"from collections import namedtuple\n",
"\n",
"import numpy as np\n",
"from scipy import sparse"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "eda67807-8c23-49ae-af20-6ce470633e6d",
"metadata": {},
"outputs": [],
"source": [
"# Ten random 3x3 CSR matrices; the loop counter itself is never used.\n",
"mats = [\n",
"    sparse.random(3, 3, density=0.3, format=\"csr\")\n",
"    for _ in range(10)\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1fc59e2d-27da-4774-ac3d-0b7b86c0662f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[0. , 0.70183216, 0. ],\n",
" [0. , 0. , 0.28669397],\n",
" [0. , 0.78484662, 0. ]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dense view of the second matrix, to compare against its raw CSR fields\n",
"mats[1].todense()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "94ff5abb-c47f-408d-95cd-9c3c6415b9a1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.70183216, 0.28669397, 0.78484662]),\n",
" array([1, 2, 1], dtype=int32),\n",
" array([0, 1, 2, 3], dtype=int32),\n",
" 3)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Raw CSR buffers of the same matrix: values, column indices, row pointers, and stored-entry count\n",
"mats[1].data, mats[1].indices, mats[1].indptr, mats[1].nnz"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "548d87bd-337c-43fe-89de-8d7ef300dcff",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BatchCSRMatrix(data=array([0.05210203, 0.91588562, 0.72053483, 0.70183216, 0.28669397,\n",
" 0.78484662, 0.40953651, 0.9790061 , 0.36297553, 0.52134413,\n",
" 0.58894459, 0.88231634, 0.96920377, 0.16951942, 0.23246895,\n",
" 0.80660814, 0.59596091, 0.99579766, 0.3800977 , 0.09521281,\n",
" 0.70456014, 0.46268364, 0.27507424, 0.34334122, 0.07747538,\n",
" 0.77611234, 0.95426515, 0.31211267, 0.53334523, 0.07058711]), indices=array([1, 1, 0, 1, 2, 1, 0, 2, 0, 0, 1, 2, 0, 0, 2, 2, 1, 2, 0, 1, 1, 0,\n",
" 1, 2, 0, 1, 2, 0, 1, 1], dtype=int32), indptr=array([0, 1, 2, 3, 0, 1, 2, 3, 0, 0, 2, 3, 0, 0, 3, 3, 0, 1, 3, 3, 0, 0,\n",
" 1, 3, 0, 1, 2, 3, 0, 0, 2, 3, 0, 0, 3, 3, 0, 0, 2, 3], dtype=int32), cum_nnz=array([ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30]), shape=(10, 3, 3), dtype=dtype('float64'), kind='csr')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"BatchMatrix = namedtuple(\n",
"    \"BatchCSRMatrix\",\n",
"    (\"data\", \"indices\", \"indptr\", \"cum_nnz\", \"shape\", \"dtype\", \"kind\"),\n",
")\n",
"\n",
"# Pack every matrix into shared flat buffers. Each matrix contributes its\n",
"# whole indptr array, but only nnz entries to data/indices, so cum_nnz\n",
"# records where each matrix's slice of data/indices begins and ends.\n",
"batch_matrix = BatchMatrix(\n",
"    data=np.concatenate([mat.data for mat in mats]),\n",
"    indices=np.concatenate([mat.indices for mat in mats]),\n",
"    indptr=np.concatenate([mat.indptr for mat in mats]),\n",
"    # Leading 0 so that cum_nnz[i]:cum_nnz[i + 1] bounds matrix i\n",
"    cum_nnz=np.cumsum([0] + [mat.nnz for mat in mats]),\n",
"    shape=(len(mats),) + mats[0].shape,\n",
"    dtype=mats[0].dtype,\n",
"    kind=\"csr\",\n",
")\n",
"\n",
"batch_matrix"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2e84f8f1-b646-45ac-9440-c26d20bb1d45",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0, 1, 2, 3], dtype=int32),\n",
" array([1, 2, 1], dtype=int32),\n",
" array([0.70183216, 0.28669397, 0.78484662]))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Recover the buffers of matrix i from the packed batch.\n",
"i = 1\n",
"m = batch_matrix.shape[-2]\n",
"# Every matrix owns m + 1 consecutive indptr entries\n",
"indptr = batch_matrix.indptr[i * (m + 1) : (i + 1) * (m + 1)]\n",
"# indices and data share the same cum_nnz window\n",
"indices, data = (\n",
"    field[batch_matrix.cum_nnz[i] : batch_matrix.cum_nnz[i + 1]]\n",
"    for field in (batch_matrix.indices, batch_matrix.data)\n",
")\n",
"indptr, indices, data"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dccdc04c-137d-4236-9c58-71c5e7156cb0",
"metadata": {},
"outputs": [],
"source": [
"# Rebuild matrix i from the sliced buffers (passing copy=False) and check it round-trips.\n",
"x = sparse.csr_matrix(\n",
"    (data, indices, indptr),\n",
"    shape=batch_matrix.shape[-2:],\n",
"    copy=False,\n",
")\n",
"np.testing.assert_allclose(x.toarray(), mats[i].toarray())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5d373e08-da6b-406e-91a9-79fed41ecff1",
"metadata": {},
"outputs": [],
"source": [
"def matrix_transpose(bmat):\n",
"    \"\"\"Transpose every matrix in the batch without touching the buffers.\n",
"\n",
"    The same data/indices/indptr arrays describe a CSR matrix of shape\n",
"    (m, n) and a CSC matrix of shape (n, m), and the two are transposes of\n",
"    each other. Transposing therefore flips ``kind`` and swaps the last two\n",
"    entries of ``shape``.\n",
"\n",
"    Raises ValueError if ``bmat.kind`` is neither \"csr\" nor \"csc\".\n",
"    \"\"\"\n",
"    flipped = {\"csr\": \"csc\", \"csc\": \"csr\"}\n",
"    if bmat.kind not in flipped:\n",
"        raise ValueError(f\"Unsupported sparse kind: {bmat.kind!r}\")\n",
"    batch, m, n = bmat.shape\n",
"    return bmat._replace(kind=flipped[bmat.kind], shape=(batch, n, m))\n",
"\n",
"\n",
"def dot(bmat, v):\n",
"    \"\"\"Multiply each matrix of the batch by the 1-D vector ``v``.\n",
"\n",
"    ``bmat.shape`` is (batch, m, n) and ``v`` must have shape (n,); the\n",
"    result is a dense array of shape (batch, m).\n",
"\n",
"    Raises ValueError for an unknown ``bmat.kind`` or a mismatched ``v``.\n",
"    \"\"\"\n",
"    if bmat.kind == \"csr\":\n",
"        mat_kind = sparse.csr_matrix\n",
"    elif bmat.kind == \"csc\":\n",
"        mat_kind = sparse.csc_matrix\n",
"    else:\n",
"        raise ValueError(f\"Unsupported sparse kind: {bmat.kind!r}\")\n",
"\n",
"    batch, m, n = bmat.shape\n",
"    v = np.asarray(v)\n",
"    if v.shape != (n,):\n",
"        raise ValueError(f\"Expected v of shape ({n},), got {v.shape}\")\n",
"\n",
"    # indptr has one entry per compressed axis plus one: rows for CSR,\n",
"    # columns for CSC.\n",
"    stride = (m if bmat.kind == \"csr\" else n) + 1\n",
"    # One scratch matrix is reused for the whole batch; only its three\n",
"    # buffers are swapped on each iteration.\n",
"    temp_mat = mat_kind((m, n), dtype=bmat.dtype)\n",
"    # Promote with v's dtype so that e.g. an integer batch times a float\n",
"    # vector is not truncated when stored.\n",
"    res_mat = np.empty((batch, m), dtype=np.result_type(bmat.dtype, v.dtype))\n",
"    for i in range(batch):\n",
"        start, end = bmat.cum_nnz[i : i + 2]\n",
"        temp_mat.data = bmat.data[start:end]\n",
"        temp_mat.indices = bmat.indices[start:end]\n",
"        temp_mat.indptr = bmat.indptr[stride * i : stride * (i + 1)]\n",
"        res_mat[i] = temp_mat @ v\n",
"    return res_mat"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a8acbae2-7f7c-4bdc-9955-e0336d21539d",
"metadata": {},
"outputs": [],
"source": [
"# Dense copy of the same matrices stacked along a new leading batch axis,\n",
"# used as the reference for the sparse results.\n",
"mat_concat = np.stack([mat.toarray() for mat in mats])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f3828dd5-eab8-49d6-9687-c88401aedf76",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.05210203, 0.91588562, 0. ],\n",
" [0.70183216, 0.57338794, 0.78484662],\n",
" [0. , 1.9580122 , 0. ],\n",
" [0. , 2.35357727, 0. ],\n",
" [0. , 0.4649379 , 0. ],\n",
" [0. , 1.61321627, 2.58755623],\n",
" [0. , 0.09521281, 0.70456014],\n",
" [0. , 0.27507424, 0.68668245],\n",
" [0. , 2.68464263, 0. ],\n",
" [0. , 0.53334523, 0.07058711]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Batched sparse product: one matrix-vector product per matrix in the batch\n",
"v = np.arange(3)\n",
"dot(batch_matrix, v)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "693848f6-1b8a-4faa-8a87-26be9df7a872",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.05210203, 0.91588562, 0. ],\n",
" [0.70183216, 0.57338794, 0.78484662],\n",
" [0. , 1.9580122 , 0. ],\n",
" [0. , 2.35357727, 0. ],\n",
" [0. , 0.4649379 , 0. ],\n",
" [0. , 1.61321627, 2.58755623],\n",
" [0. , 0.09521281, 0.70456014],\n",
" [0. , 0.27507424, 0.68668245],\n",
" [0. , 2.68464263, 0. ],\n",
" [0. , 0.53334523, 0.07058711]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dense reference for the batched sparse product above\n",
"mat_concat @ v"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c4646635-7c7d-404d-af06-1b7da96a4d21",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1.44106966, 0.91588562, 0. ],\n",
" [0. , 1.56969323, 0.28669397],\n",
" [1.13548757, 0. , 0.9790061 ],\n",
" [0.52134413, 0.58894459, 0.88231634],\n",
" [0.16951942, 0. , 0.23246895],\n",
" [0. , 1.19192182, 2.79820345],\n",
" [0. , 1.5043331 , 0. ],\n",
" [0.46268364, 0.27507424, 0.68668245],\n",
" [0.07747538, 0.77611234, 0.95426515],\n",
" [0.31211267, 0.67451946, 0. ]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same product after transposing every matrix in the batch\n",
"dot(matrix_transpose(batch_matrix), v)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "81bedbbb-99f1-443c-b74e-a93f1845eba9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1.44106966, 0.91588562, 0. ],\n",
" [0. , 1.56969323, 0.28669397],\n",
" [1.13548757, 0. , 0.9790061 ],\n",
" [0.52134413, 0.58894459, 0.88231634],\n",
" [0.16951942, 0. , 0.23246895],\n",
" [0. , 1.19192182, 2.79820345],\n",
" [0. , 1.5043331 , 0. ],\n",
" [0.46268364, 0.27507424, 0.68668245],\n",
" [0.07747538, 0.77611234, 0.95426515],\n",
" [0.31211267, 0.67451946, 0. ]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dense reference: swap the last two axes of each matrix, then multiply by v\n",
"mat_concat.mT @ v"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9749cc5e-dbaa-4945-99d7-69e02d2b501e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytensor-dev",
"language": "python",
"name": "pytensor-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment