Last active
October 14, 2025 10:25
-
-
Save ricardoV94/4e57e311db860bf5264267a86553e17c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "1a5344cc-2a05-46a5-b31b-29ec3bd63b5a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from collections import namedtuple\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "from scipy import sparse" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "eda67807-8c23-49ae-af20-6ce470633e6d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Seed the generator so the random batch -- and every output computed from it
# below -- is reproducible under Restart Kernel -> Run All.
rng = np.random.default_rng(0)
# Ten 3x3 CSR matrices at 30% density, all drawn from the same seeded generator.
mats = [
    sparse.random(3, 3, density=0.3, format="csr", random_state=rng)
    for _ in range(10)
]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "1fc59e2d-27da-4774-ac3d-0b7b86c0662f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "matrix([[0. , 0.70183216, 0. ],\n", | |
| " [0. , 0. , 0.28669397],\n", | |
| " [0. , 0.78484662, 0. ]])" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "mats[1].todense()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "94ff5abb-c47f-408d-95cd-9c3c6415b9a1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(array([0.70183216, 0.28669397, 0.78484662]),\n", | |
| " array([1, 2, 1], dtype=int32),\n", | |
| " array([0, 1, 2, 3], dtype=int32),\n", | |
| " 3)" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "mats[1].data, mats[1].indices, mats[1].indptr, mats[1].nnz" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "548d87bd-337c-43fe-89de-8d7ef300dcff", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "BatchCSRMatrix(data=array([0.05210203, 0.91588562, 0.72053483, 0.70183216, 0.28669397,\n", | |
| " 0.78484662, 0.40953651, 0.9790061 , 0.36297553, 0.52134413,\n", | |
| " 0.58894459, 0.88231634, 0.96920377, 0.16951942, 0.23246895,\n", | |
| " 0.80660814, 0.59596091, 0.99579766, 0.3800977 , 0.09521281,\n", | |
| " 0.70456014, 0.46268364, 0.27507424, 0.34334122, 0.07747538,\n", | |
| " 0.77611234, 0.95426515, 0.31211267, 0.53334523, 0.07058711]), indices=array([1, 1, 0, 1, 2, 1, 0, 2, 0, 0, 1, 2, 0, 0, 2, 2, 1, 2, 0, 1, 1, 0,\n", | |
| " 1, 2, 0, 1, 2, 0, 1, 1], dtype=int32), indptr=array([0, 1, 2, 3, 0, 1, 2, 3, 0, 0, 2, 3, 0, 0, 3, 3, 0, 1, 3, 3, 0, 0,\n", | |
| " 1, 3, 0, 1, 2, 3, 0, 0, 2, 3, 0, 0, 3, 3, 0, 0, 2, 3], dtype=int32), cum_nnz=array([ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30]), shape=(10, 3, 3), dtype=dtype('float64'), kind='csr')" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Flat container for a batch of equally-shaped sparse matrices.
# The typename matches the name it is bound to, so instances can be pickled and
# the repr stays accurate when `kind` is not "csr".
BatchMatrix = namedtuple(
    "BatchMatrix",
    ("data", "indices", "indptr", "cum_nnz", "shape", "dtype", "kind"),
)

# The layout below locates each matrix's indptr purely by position, which is only
# valid when every matrix is CSR and all of them share one shape.
assert len(mats) > 0, "need at least one matrix to build a batch"
assert all(m.format == "csr" for m in mats), "all matrices must be CSR"
assert all(m.shape == mats[0].shape for m in mats), "all matrices must share one shape"

batch_matrix = BatchMatrix(
    data=np.concatenate([m.data for m in mats]),
    indices=np.concatenate([m.indices for m in mats]),
    indptr=np.concatenate([m.indptr for m in mats]),
    # cum_nnz[i]:cum_nnz[i + 1] is the slice of data/indices owned by matrix i.
    cum_nnz=np.cumsum([0, *(m.nnz for m in mats)]),
    shape=(len(mats), *mats[0].shape),
    dtype=mats[0].dtype,
    kind="csr",
)

batch_matrix
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "2e84f8f1-b646-45ac-9440-c26d20bb1d45", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(array([0, 1, 2, 3], dtype=int32),\n", | |
| " array([1, 2, 1], dtype=int32),\n", | |
| " array([0.70183216, 0.28669397, 0.78484662]))" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Pull the buffers of the i-th matrix back out of the flat batch arrays.
i = 1
m = batch_matrix.shape[-2]  # second-to-last shape entry; each indptr chunk has m + 1 items

indptr_span = slice((m + 1) * i, (m + 1) * (i + 1))
nnz_span = slice(batch_matrix.cum_nnz[i], batch_matrix.cum_nnz[i + 1])

indptr = batch_matrix.indptr[indptr_span]
indices = batch_matrix.indices[nnz_span]
data = batch_matrix.data[nnz_span]
indptr, indices, data
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "dccdc04c-137d-4236-9c58-71c5e7156cb0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Rebuild matrix i from the sliced buffers (copy=False asks scipy not to copy them)
# and check that it matches the matrix the batch was built from.
x = sparse.csr_matrix((data, indices, indptr), shape=batch_matrix.shape[-2:], copy=False)
# toarray() gives plain ndarrays; todense() would return the discouraged np.matrix.
np.testing.assert_allclose(x.toarray(), mats[i].toarray())
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "5d373e08-da6b-406e-91a9-79fed41ecff1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def matrix_transpose(bmat):
    """Transpose every matrix of the batch without touching the buffers.

    The CSR arrays of an ``(m, n)`` matrix are exactly the CSC arrays of its
    ``(n, m)`` transpose (and vice versa), so only the ``kind`` tag flips and
    the last two entries of ``shape`` swap.

    Parameters
    ----------
    bmat : namedtuple with fields data, indices, indptr, cum_nnz, shape, dtype, kind
        ``shape`` is ``(batch, m, n)``; ``kind`` is ``"csr"`` or ``"csc"``.

    Returns
    -------
    Same namedtuple type, sharing all buffers with ``bmat``, with the other
    ``kind`` and ``shape == (batch, n, m)``.
    """
    batch, m, n = bmat.shape
    new_kind = "csc" if bmat.kind == "csr" else "csr"
    # Swapping the shape is what makes non-square batches transpose correctly.
    return bmat._replace(kind=new_kind, shape=(batch, n, m))


def dot(bmat, v):
    """Multiply every matrix of the batch by the same right-hand side ``v``.

    Parameters
    ----------
    bmat : namedtuple with fields data, indices, indptr, cum_nnz, shape, dtype, kind
        ``shape`` is ``(batch, m, n)``. ``kind == "csr"`` selects CSR; any other
        value is treated as CSC.
    v : array_like
        Either a vector of shape ``(n,)`` or a matrix of shape ``(n, k)``.

    Returns
    -------
    numpy.ndarray
        Shape ``(batch, m)`` for a vector, ``(batch, m, k)`` for a matrix, in the
        promoted dtype of ``bmat.dtype`` and ``v``.
    """
    batch, m, n = bmat.shape
    if bmat.kind == "csr":
        mat_kind, n_major = sparse.csr_matrix, m
    else:
        mat_kind, n_major = sparse.csc_matrix, n
    # indptr has one entry per compressed axis plus one: rows for CSR, columns for CSC.
    stride = n_major + 1

    v = np.asarray(v)
    # Promote instead of forcing bmat.dtype, so e.g. a complex `v` keeps its imaginary part.
    out_dtype = np.result_type(bmat.dtype, v.dtype)
    res_mat = np.empty((batch, m, *v.shape[1:]), dtype=out_dtype)

    # One scratch matrix is reused for the whole loop; only its three buffers are
    # re-pointed at slices of the batch arrays, so nothing is rebuilt per item.
    temp_mat = mat_kind((m, n), dtype=bmat.dtype)
    for i in range(batch):
        start, end = bmat.cum_nnz[i : i + 2]
        temp_mat.indptr = bmat.indptr[stride * i : stride * (i + 1)]
        temp_mat.indices = bmat.indices[start:end]
        temp_mat.data = bmat.data[start:end]
        res_mat[i] = temp_mat @ v
    return res_mat
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "a8acbae2-7f7c-4bdc-9955-e0336d21539d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Dense reference copy of the batch, stacked along a new leading axis, used to
# cross-check the sparse results.
dense_slices = [m.toarray() for m in mats]
mat_concat = np.stack(dense_slices)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "f3828dd5-eab8-49d6-9687-c88401aedf76", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.05210203, 0.91588562, 0. ],\n", | |
| " [0.70183216, 0.57338794, 0.78484662],\n", | |
| " [0. , 1.9580122 , 0. ],\n", | |
| " [0. , 2.35357727, 0. ],\n", | |
| " [0. , 0.4649379 , 0. ],\n", | |
| " [0. , 1.61321627, 2.58755623],\n", | |
| " [0. , 0.09521281, 0.70456014],\n", | |
| " [0. , 0.27507424, 0.68668245],\n", | |
| " [0. , 2.68464263, 0. ],\n", | |
| " [0. , 0.53334523, 0.07058711]])" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Size the test vector from the batch itself instead of hard-coding 3, so the
# cell keeps working if the matrix shape above is changed.
v = np.arange(batch_matrix.shape[-1])
dot(batch_matrix, v)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "693848f6-1b8a-4faa-8a87-26be9df7a872", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[0.05210203, 0.91588562, 0. ],\n", | |
| " [0.70183216, 0.57338794, 0.78484662],\n", | |
| " [0. , 1.9580122 , 0. ],\n", | |
| " [0. , 2.35357727, 0. ],\n", | |
| " [0. , 0.4649379 , 0. ],\n", | |
| " [0. , 1.61321627, 2.58755623],\n", | |
| " [0. , 0.09521281, 0.70456014],\n", | |
| " [0. , 0.27507424, 0.68668245],\n", | |
| " [0. , 2.68464263, 0. ],\n", | |
| " [0. , 0.53334523, 0.07058711]])" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "mat_concat @ v" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "c4646635-7c7d-404d-af06-1b7da96a4d21", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[1.44106966, 0.91588562, 0. ],\n", | |
| " [0. , 1.56969323, 0.28669397],\n", | |
| " [1.13548757, 0. , 0.9790061 ],\n", | |
| " [0.52134413, 0.58894459, 0.88231634],\n", | |
| " [0.16951942, 0. , 0.23246895],\n", | |
| " [0. , 1.19192182, 2.79820345],\n", | |
| " [0. , 1.5043331 , 0. ],\n", | |
| " [0.46268364, 0.27507424, 0.68668245],\n", | |
| " [0.07747538, 0.77611234, 0.95426515],\n", | |
| " [0.31211267, 0.67451946, 0. ]])" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "dot(matrix_transpose(batch_matrix), v)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "81bedbbb-99f1-443c-b74e-a93f1845eba9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[1.44106966, 0.91588562, 0. ],\n", | |
| " [0. , 1.56969323, 0.28669397],\n", | |
| " [1.13548757, 0. , 0.9790061 ],\n", | |
| " [0.52134413, 0.58894459, 0.88231634],\n", | |
| " [0.16951942, 0. , 0.23246895],\n", | |
| " [0. , 1.19192182, 2.79820345],\n", | |
| " [0. , 1.5043331 , 0. ],\n", | |
| " [0.46268364, 0.27507424, 0.68668245],\n", | |
| " [0.07747538, 0.77611234, 0.95426515],\n", | |
| " [0.31211267, 0.67451946, 0. ]])" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "mat_concat.mT @ v" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "9749cc5e-dbaa-4945-99d7-69e02d2b501e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "pytensor-dev", | |
| "language": "python", | |
| "name": "pytensor-dev" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.8" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment