Skip to content

Instantly share code, notes, and snippets.

@cloneofsimo
Created October 31, 2024 09:43
Show Gist options
  • Save cloneofsimo/191bc3cd1f65a9f900f2dc8a6b486fe7 to your computer and use it in GitHub Desktop.
extended_syevjBatched torch
# batch_eigendecomp.py
import torch
from torch.utils.cpp_extension import load_inline
import argparse
import os
import shutil
def clear_cuda_cache():
    """Delete the torch_extensions build cache so inline extensions recompile from scratch."""
    cache_path = os.path.expanduser('~/.cache/torch_extensions')
    if not os.path.exists(cache_path):
        return
    shutil.rmtree(cache_path)
    print(f"Cleared PyTorch extensions cache at {cache_path}")
# C++ declaration passed to load_inline; must match the definition of
# forward() inside cuda_source below.
cpp_source = """
#include <torch/extension.h>
std::vector<torch::Tensor> forward(torch::Tensor input, double atol, int max_sweeps);
"""
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <vector>
#define CHECK_CUDA(err) do { if (err != cudaSuccess) { \
printf("CUDA error: %s at line %d in file %s\\n", \
cudaGetErrorString(err), __LINE__, __FILE__); \
throw std::runtime_error("CUDA error"); \
}} while (0)
#define CHECK_CUSOLVER(err) do { if (err != CUSOLVER_STATUS_SUCCESS) { \
printf("CUSOLVER error: %d at line %d in file %s\\n", \
err, __LINE__, __FILE__); \
throw std::runtime_error("CUSOLVER error"); \
}} while (0)
template<typename scalar_t>
void eigendecomp_kernel_impl(
const torch::Tensor& input,
torch::Tensor& eigenvectors,
torch::Tensor& eigenvalues,
scalar_t atol,
int max_sweeps) {
const auto batch_size = input.size(0);
const auto n = input.size(1);
const auto lda = n;
cusolverDnHandle_t cusolverH = nullptr;
cudaStream_t stream = nullptr;
syevjInfo_t syevj_params = nullptr;
CHECK_CUSOLVER(cusolverDnCreate(&cusolverH));
CHECK_CUDA(cudaStreamCreate(&stream));
CHECK_CUSOLVER(cusolverDnSetStream(cusolverH, stream));
CHECK_CUSOLVER(cusolverDnCreateSyevjInfo(&syevj_params));
const int sort_eig = 0; // Don't sort eigenvalues
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
const cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER;
CHECK_CUSOLVER(cusolverDnXsyevjSetTolerance(syevj_params, atol));
CHECK_CUSOLVER(cusolverDnXsyevjSetMaxSweeps(syevj_params, max_sweeps));
CHECK_CUSOLVER(cusolverDnXsyevjSetSortEig(syevj_params, sort_eig));
int lwork = 0;
if (std::is_same<scalar_t, double>::value) {
CHECK_CUSOLVER(cusolverDnDsyevjBatched_bufferSize(
cusolverH, jobz, uplo, n,
reinterpret_cast<double*>(input.data_ptr<scalar_t>()),
lda,
reinterpret_cast<double*>(eigenvalues.data_ptr<scalar_t>()),
&lwork,
syevj_params,
batch_size
));
} else {
CHECK_CUSOLVER(cusolverDnSsyevjBatched_bufferSize(
cusolverH, jobz, uplo, n,
reinterpret_cast<float*>(input.data_ptr<scalar_t>()),
lda,
reinterpret_cast<float*>(eigenvalues.data_ptr<scalar_t>()),
&lwork,
syevj_params,
batch_size
));
}
auto workspace = torch::empty({lwork}, input.options());
auto info = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(input.device()));
eigenvectors.copy_(input);
if (std::is_same<scalar_t, double>::value) {
CHECK_CUSOLVER(cusolverDnDsyevjBatched(
cusolverH, jobz, uplo, n,
reinterpret_cast<double*>(eigenvectors.data_ptr<scalar_t>()),
lda,
reinterpret_cast<double*>(eigenvalues.data_ptr<scalar_t>()),
reinterpret_cast<double*>(workspace.data_ptr<scalar_t>()),
lwork,
reinterpret_cast<int*>(info.data_ptr<int>()),
syevj_params,
batch_size
));
} else {
CHECK_CUSOLVER(cusolverDnSsyevjBatched(
cusolverH, jobz, uplo, n,
reinterpret_cast<float*>(eigenvectors.data_ptr<scalar_t>()),
lda,
reinterpret_cast<float*>(eigenvalues.data_ptr<scalar_t>()),
reinterpret_cast<float*>(workspace.data_ptr<scalar_t>()),
lwork,
reinterpret_cast<int*>(info.data_ptr<int>()),
syevj_params,
batch_size
));
}
CHECK_CUSOLVER(cusolverDnDestroySyevjInfo(syevj_params));
CHECK_CUSOLVER(cusolverDnDestroy(cusolverH));
CHECK_CUDA(cudaStreamDestroy(stream));
}
void eigendecomp_kernel(
const torch::Tensor& input,
torch::Tensor& eigenvectors,
torch::Tensor& eigenvalues,
double atol,
int max_sweeps) {
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "eigendecomp_kernel", ([&] {
eigendecomp_kernel_impl<scalar_t>(
input, eigenvectors, eigenvalues,
static_cast<scalar_t>(atol), max_sweeps);
}));
}
std::vector<torch::Tensor> forward(torch::Tensor input, double atol, int max_sweeps) {
TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(input.dim() == 3, "Input tensor must be 3D (batch_size, n, n)");
TORCH_CHECK(input.size(1) == input.size(2), "Input tensor must be square matrices");
TORCH_CHECK(input.scalar_type() == torch::kFloat32 || input.scalar_type() == torch::kFloat64,
"Input must be float32 or float64");
auto batch_size = input.size(0);
auto n = input.size(1);
auto eigenvectors = torch::empty_like(input);
auto eigenvalues = torch::empty({batch_size, n}, input.options());
eigendecomp_kernel(input, eigenvectors, eigenvalues, atol, max_sweeps);
return {eigenvectors, eigenvalues};
}
"""
def init_extension(clear_cache=False):
    """Compile and load the inline CUDA extension.

    Args:
        clear_cache: When True, wipe the torch_extensions build cache first
            so the extension is recompiled from scratch.

    Returns:
        The loaded extension module exposing ``forward``.
    """
    if clear_cache:
        clear_cuda_cache()
    extension = load_inline(
        name="batch_eigendecomp",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        functions=["forward"],
        extra_cflags=["-O3"],
        extra_cuda_cflags=["-O3"],
        extra_ldflags=["-lcusolver"],  # link against cuSOLVER for syevjBatched
        verbose=True,
        build_directory=None,
    )
    return extension
# Lazily-built module-level handle to the compiled extension.
batch_eigendecomp_cuda = None


def get_extension(clear_cache=False):
    """Return the compiled extension, building it on first use or when clear_cache is True."""
    global batch_eigendecomp_cuda
    if clear_cache or batch_eigendecomp_cuda is None:
        batch_eigendecomp_cuda = init_extension(clear_cache)
    return batch_eigendecomp_cuda
class BatchEigendecomp(torch.autograd.Function):
    """autograd.Function wrapper for the batched syevj CUDA kernel.

    NOTE(review): only ``forward`` is defined — calling ``.backward()`` on the
    outputs will raise, since no gradient formula is provided.
    """

    @staticmethod
    def forward(ctx, input, atol=1e-7, max_sweeps=100):
        # The CUDA kernel requires a contiguous buffer.
        contiguous_input = input if input.is_contiguous() else input.contiguous()
        eigenvectors, eigenvalues = get_extension().forward(contiguous_input, atol, max_sweeps)
        return eigenvectors, eigenvalues
def batch_eigendecomp(x, atol=1e-7, max_sweeps=100, clear_cache=False):
    """
    Computes eigendecomposition for a batch of symmetric matrices.
    Args:
        x: Input tensor of shape [batch_size, n, n]
        Must be symmetric, CUDA tensor in float32 or float64 dtype
        atol: Absolute tolerance for convergence (default: 1e-7)
        max_sweeps: Maximum number of sweeps for the Jacobi algorithm (default: 100)
        clear_cache: If True, clears the PyTorch extensions cache before compiling
    Returns:
        tuple: (eigenvectors, eigenvalues)
        - eigenvectors: tensor of shape [batch_size, n, n]
        - eigenvalues: tensor of shape [batch_size, n]
    """
    if clear_cache:
        # Force a rebuild of the extension before dispatching the computation.
        get_extension(clear_cache=True)
    return BatchEigendecomp.apply(x, atol, max_sweeps)
import time
import numpy as np
import pandas as pd
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--clear-cache', action='store_true', help='Clear PyTorch extensions cache')
    parser.add_argument('--atol', type=float, default=1e-6, help='Absolute tolerance')
    parser.add_argument('--max-sweeps', type=int, default=10, help='Maximum number of sweeps')
    parser.add_argument('--dtype', choices=['float32', 'float64'], default='float32', help='Data type')
    args = parser.parse_args()

    # The cache is cleared (at most) ONCE, up front.
    if args.clear_cache:
        clear_cuda_cache()

    dtype = torch.float32 if args.dtype == 'float32' else torch.float64

    stats = []
    for logbs in [6]:
        for logmatsize in [9]:
            batch_size = 2**logbs
            n = 2**logmatsize
            X = torch.randn(batch_size, n, n, dtype=dtype, device='cuda:0')
            X = X + X.transpose(-2, -1)  # Make symmetric

            # Warm-up call: triggers extension compilation (if needed) and CUDA
            # context setup so the timed loop below measures only the solver.
            batch_eigendecomp(X, atol=args.atol, max_sweeps=args.max_sweeps)

            for atol in [1e-1, 1e-3, 1e-5, 1e-7, 1e-9, 1e-11]:
                for max_sweeps in [2, 4, 8, 16, 32, 64]:
                    torch.cuda.synchronize()
                    ttime = []
                    errors = []
                    for i in range(10):
                        start = time.time()
                        # BUG FIX: do NOT forward args.clear_cache here. The
                        # original code did, which (with --clear-cache) wiped the
                        # build cache and recompiled the extension inside every
                        # timed iteration, corrupting all timing numbers.
                        eigenvectors, eigenvalues = batch_eigendecomp(
                            X, atol=atol, max_sweeps=max_sweeps
                        )
                        torch.cuda.synchronize()
                        end = time.time()
                        ttime.append(end - start)
                        # Check reconstruction error for first matrix of batch.
                        reconstructed = eigenvectors[0].T @ torch.diag(eigenvalues[0]) @ eigenvectors[0]
                        error = torch.abs(X[0] - reconstructed).max().item()
                        errors.append(error)
                    stats.append((
                        batch_size, n, dtype, atol, max_sweeps,
                        np.mean(ttime), np.std(ttime), np.median(ttime),
                        np.mean(errors), np.max(errors)
                    ))
                    print(f"==== Batch size: {batch_size}, Matrix size: {n}, dtype: {dtype}, atol: {atol}, max_sweeps: {max_sweeps} ====")
                    print(f"\tAverage time: {np.mean(ttime):.2f} seconds")
                    print(f"\tStd time: {np.std(ttime):.2f} seconds")
                    print(f"\tMedian time: {np.median(ttime):.2f} seconds")
                    print(f"\tAverage error: {np.mean(errors):.2e}")
                    print(f"\tMax error: {np.max(errors):.2e}")

    # Save results to CSV (pandas is already imported at module level).
    df = pd.DataFrame(stats, columns=[
        'batch_size', 'n', 'dtype', 'atol', 'max_sweeps',
        'mean_time', 'std_time', 'median_time',
        'mean_error', 'max_error'
    ])
    df.to_csv('batch_eigendecomp_stats.csv', index=False)

    # Plot 1: x = log10(atol), y = log2(max_sweeps), color = log10(mean_time).
    import matplotlib.pyplot as plt
    plt.figure(figsize=(12, 8))
    unique_batch_sizes = df['batch_size'].unique()
    unique_ns = df['n'].unique()
    for bs in unique_batch_sizes:
        for n in unique_ns:
            subset = df[(df['batch_size'] == bs) & (df['n'] == n)]
            if len(subset) == 0:
                continue
            scatter = plt.scatter(
                np.log10(subset['atol']),
                np.log2(subset['max_sweeps']),
                c=np.log10(subset['mean_time']),
                cmap='viridis',
                s=100,
                label=f'bs={bs}, n={n}'
            )
    plt.colorbar(label='log10(Mean Time (s))')
    plt.xlabel('log10(atol)')
    plt.ylabel('log2(max_sweeps)')
    plt.title('Eigendecomposition Performance')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig('eigendecomp_performance.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Plot 2: x = log10(mean_error), y = log10(mean_time), color = log2(max_sweeps).
    plt.figure(figsize=(12, 8))
    for bs in unique_batch_sizes:
        for n in unique_ns:
            subset = df[(df['batch_size'] == bs) & (df['n'] == n)]
            if len(subset) == 0:
                continue
            scatter = plt.scatter(
                np.log10(subset['mean_error']),
                np.log10(subset['mean_time']),
                c=np.log2(subset['max_sweeps']),
                cmap='viridis',
                s=100,
                label=f'bs={bs}, n={n}'
            )
    plt.colorbar(label='log2(max_sweeps)')
    plt.xlabel('log10(Mean Error)')
    plt.ylabel('log10(Mean Time (s))')
    plt.title('Error vs Time Trade-off')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig('error_vs_time.png', dpi=300, bbox_inches='tight')
    plt.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment