extended_syevjBatched torch: batched symmetric eigendecomposition in PyTorch via cuSOLVER's Jacobi solver (syevjBatched), compiled as an inline CUDA extension, with tunable tolerance and sweep count plus a benchmark sweep.
# batch_eigendecomp.py
import torch
from torch.utils.cpp_extension import load_inline
import argparse
import os
import shutil

def clear_cuda_cache():
    cache_path = os.path.expanduser('~/.cache/torch_extensions')
    if os.path.exists(cache_path):
        shutil.rmtree(cache_path)
        print(f"Cleared PyTorch extensions cache at {cache_path}")
cpp_source = """ | |
#include <torch/extension.h> | |
std::vector<torch::Tensor> forward(torch::Tensor input, double atol, int max_sweeps); | |
""" | |
cuda_source = """ | |
#include <torch/extension.h> | |
#include <cuda.h> | |
#include <cuda_runtime.h> | |
#include <cusolverDn.h> | |
#include <vector> | |
#define CHECK_CUDA(err) do { if (err != cudaSuccess) { \ | |
printf("CUDA error: %s at line %d in file %s\\n", \ | |
cudaGetErrorString(err), __LINE__, __FILE__); \ | |
throw std::runtime_error("CUDA error"); \ | |
}} while (0) | |
#define CHECK_CUSOLVER(err) do { if (err != CUSOLVER_STATUS_SUCCESS) { \ | |
printf("CUSOLVER error: %d at line %d in file %s\\n", \ | |
err, __LINE__, __FILE__); \ | |
throw std::runtime_error("CUSOLVER error"); \ | |
}} while (0) | |
template<typename scalar_t> | |
void eigendecomp_kernel_impl( | |
const torch::Tensor& input, | |
torch::Tensor& eigenvectors, | |
torch::Tensor& eigenvalues, | |
scalar_t atol, | |
int max_sweeps) { | |
const auto batch_size = input.size(0); | |
const auto n = input.size(1); | |
const auto lda = n; | |
cusolverDnHandle_t cusolverH = nullptr; | |
cudaStream_t stream = nullptr; | |
syevjInfo_t syevj_params = nullptr; | |
CHECK_CUSOLVER(cusolverDnCreate(&cusolverH)); | |
CHECK_CUDA(cudaStreamCreate(&stream)); | |
CHECK_CUSOLVER(cusolverDnSetStream(cusolverH, stream)); | |
CHECK_CUSOLVER(cusolverDnCreateSyevjInfo(&syevj_params)); | |
const int sort_eig = 0; // Don't sort eigenvalues | |
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; | |
const cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER; | |
CHECK_CUSOLVER(cusolverDnXsyevjSetTolerance(syevj_params, atol)); | |
CHECK_CUSOLVER(cusolverDnXsyevjSetMaxSweeps(syevj_params, max_sweeps)); | |
CHECK_CUSOLVER(cusolverDnXsyevjSetSortEig(syevj_params, sort_eig)); | |
int lwork = 0; | |
if (std::is_same<scalar_t, double>::value) { | |
CHECK_CUSOLVER(cusolverDnDsyevjBatched_bufferSize( | |
cusolverH, jobz, uplo, n, | |
reinterpret_cast<double*>(input.data_ptr<scalar_t>()), | |
lda, | |
reinterpret_cast<double*>(eigenvalues.data_ptr<scalar_t>()), | |
&lwork, | |
syevj_params, | |
batch_size | |
)); | |
} else { | |
CHECK_CUSOLVER(cusolverDnSsyevjBatched_bufferSize( | |
cusolverH, jobz, uplo, n, | |
reinterpret_cast<float*>(input.data_ptr<scalar_t>()), | |
lda, | |
reinterpret_cast<float*>(eigenvalues.data_ptr<scalar_t>()), | |
&lwork, | |
syevj_params, | |
batch_size | |
)); | |
} | |
auto workspace = torch::empty({lwork}, input.options()); | |
auto info = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(input.device())); | |
eigenvectors.copy_(input); | |
if (std::is_same<scalar_t, double>::value) { | |
CHECK_CUSOLVER(cusolverDnDsyevjBatched( | |
cusolverH, jobz, uplo, n, | |
reinterpret_cast<double*>(eigenvectors.data_ptr<scalar_t>()), | |
lda, | |
reinterpret_cast<double*>(eigenvalues.data_ptr<scalar_t>()), | |
reinterpret_cast<double*>(workspace.data_ptr<scalar_t>()), | |
lwork, | |
reinterpret_cast<int*>(info.data_ptr<int>()), | |
syevj_params, | |
batch_size | |
)); | |
} else { | |
CHECK_CUSOLVER(cusolverDnSsyevjBatched( | |
cusolverH, jobz, uplo, n, | |
reinterpret_cast<float*>(eigenvectors.data_ptr<scalar_t>()), | |
lda, | |
reinterpret_cast<float*>(eigenvalues.data_ptr<scalar_t>()), | |
reinterpret_cast<float*>(workspace.data_ptr<scalar_t>()), | |
lwork, | |
reinterpret_cast<int*>(info.data_ptr<int>()), | |
syevj_params, | |
batch_size | |
)); | |
} | |
CHECK_CUSOLVER(cusolverDnDestroySyevjInfo(syevj_params)); | |
CHECK_CUSOLVER(cusolverDnDestroy(cusolverH)); | |
CHECK_CUDA(cudaStreamDestroy(stream)); | |
} | |
void eigendecomp_kernel( | |
const torch::Tensor& input, | |
torch::Tensor& eigenvectors, | |
torch::Tensor& eigenvalues, | |
double atol, | |
int max_sweeps) { | |
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "eigendecomp_kernel", ([&] { | |
eigendecomp_kernel_impl<scalar_t>( | |
input, eigenvectors, eigenvalues, | |
static_cast<scalar_t>(atol), max_sweeps); | |
})); | |
} | |
std::vector<torch::Tensor> forward(torch::Tensor input, double atol, int max_sweeps) { | |
TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor"); | |
TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); | |
TORCH_CHECK(input.dim() == 3, "Input tensor must be 3D (batch_size, n, n)"); | |
TORCH_CHECK(input.size(1) == input.size(2), "Input tensor must be square matrices"); | |
TORCH_CHECK(input.scalar_type() == torch::kFloat32 || input.scalar_type() == torch::kFloat64, | |
"Input must be float32 or float64"); | |
auto batch_size = input.size(0); | |
auto n = input.size(1); | |
auto eigenvectors = torch::empty_like(input); | |
auto eigenvalues = torch::empty({batch_size, n}, input.options()); | |
eigendecomp_kernel(input, eigenvectors, eigenvalues, atol, max_sweeps); | |
return {eigenvectors, eigenvalues}; | |
} | |
""" | |
def init_extension(clear_cache=False):
    if clear_cache:
        clear_cuda_cache()
    return load_inline(
        name="batch_eigendecomp",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        functions=["forward"],
        extra_cuda_cflags=["-O3"],
        extra_cflags=["-O3"],
        extra_ldflags=["-lcusolver"],
        verbose=True,
        build_directory=None,
    )

batch_eigendecomp_cuda = None

def get_extension(clear_cache=False):
    # Compile lazily on first use and memoize the loaded module.
    global batch_eigendecomp_cuda
    if batch_eigendecomp_cuda is None or clear_cache:
        batch_eigendecomp_cuda = init_extension(clear_cache)
    return batch_eigendecomp_cuda
class BatchEigendecomp(torch.autograd.Function):
    # Forward-only wrapper; no backward is defined, so do not differentiate
    # through this op.
    @staticmethod
    def forward(ctx, input, atol=1e-7, max_sweeps=100):
        if not input.is_contiguous():
            input = input.contiguous()
        return tuple(get_extension().forward(input, atol, max_sweeps))
def batch_eigendecomp(x, atol=1e-7, max_sweeps=100, clear_cache=False):
    """
    Computes the eigendecomposition of a batch of symmetric matrices using
    cuSOLVER's batched Jacobi solver (syevjBatched).

    Args:
        x: Input tensor of shape [batch_size, n, n].
            Must be symmetric CUDA matrices in float32 or float64.
        atol: Absolute tolerance for convergence (default: 1e-7)
        max_sweeps: Maximum number of Jacobi sweeps (default: 100)
        clear_cache: If True, clears the PyTorch extensions cache and recompiles

    Returns:
        tuple: (eigenvectors, eigenvalues)
            - eigenvectors: tensor of shape [batch_size, n, n]; cuSOLVER writes
              the eigenvector matrix in column-major order, so row i of each
              row-major block is the i-th eigenvector
            - eigenvalues: tensor of shape [batch_size, n], unsorted
              (sorting is disabled in the kernel)
    """
    if clear_cache:
        get_extension(clear_cache=True)
    return BatchEigendecomp.apply(x, atol, max_sweeps)
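# A minimal usage sketch (an illustrative addition, not part of the original
# gist; the helper name `_demo` and the sizes are arbitrary). It compiles the
# extension on first call and checks the reconstruction A = M^T diag(w) M,
# where M is the row-major view of the returned eigenvector blocks.
def _demo():
    A = torch.randn(8, 64, 64, device='cuda')
    A = A + A.transpose(-2, -1)  # symmetrize, as syevj expects
    M, w = batch_eigendecomp(A, atol=1e-7, max_sweeps=32)
    # Row-major M is V^T under cuSOLVER's column-major convention.
    recon = M.transpose(-2, -1) @ torch.diag_embed(w) @ M
    print('max reconstruction error:', (A - recon).abs().max().item())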
import time
import numpy as np
import pandas as pd
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--clear-cache', action='store_true', help='Clear PyTorch extensions cache')
    parser.add_argument('--atol', type=float, default=1e-6, help='Absolute tolerance')
    parser.add_argument('--max-sweeps', type=int, default=10, help='Maximum number of sweeps')
    parser.add_argument('--dtype', choices=['float32', 'float64'], default='float32', help='Data type')
    args = parser.parse_args()

    if args.clear_cache:
        clear_cuda_cache()
    # Compile the extension once up front so the timings below exclude the
    # JIT build time.
    get_extension()

    dtype = torch.float32 if args.dtype == 'float32' else torch.float64
    stats = []
    for logbs in [6]:
        for logmatsize in [9]:
            batch_size = 2**logbs
            n = 2**logmatsize
            X = torch.randn(batch_size, n, n, dtype=dtype, device='cuda:0')
            X = X + X.transpose(-2, -1)  # make symmetric
            for atol in [1e-1, 1e-3, 1e-5, 1e-7, 1e-9, 1e-11]:
                for max_sweeps in [2, 4, 8, 16, 32, 64]:
                    torch.cuda.synchronize()
                    ttime = []
                    errors = []
                    for i in range(10):
                        start = time.time()
                        eigenvectors, eigenvalues = batch_eigendecomp(
                            X, atol=atol, max_sweeps=max_sweeps
                        )
                        torch.cuda.synchronize()
                        end = time.time()
                        ttime.append(end - start)
                        # Check reconstruction error on the first matrix of the
                        # batch. The row-major eigenvector block is V^T, so
                        # A = M.T @ diag(w) @ M.
                        reconstructed = eigenvectors[0].T @ torch.diag(eigenvalues[0]) @ eigenvectors[0]
                        error = torch.abs(X[0] - reconstructed).max().item()
                        errors.append(error)
                    stats.append((
                        batch_size, n, dtype, atol, max_sweeps,
                        np.mean(ttime), np.std(ttime), np.median(ttime),
                        np.mean(errors), np.max(errors)
                    ))
                    print(f"==== Batch size: {batch_size}, Matrix size: {n}, dtype: {dtype}, atol: {atol}, max_sweeps: {max_sweeps} ====")
                    print(f"\tAverage time: {np.mean(ttime):.4f} seconds")
                    print(f"\tStd time: {np.std(ttime):.4f} seconds")
                    print(f"\tMedian time: {np.median(ttime):.4f} seconds")
                    print(f"\tAverage error: {np.mean(errors):.2e}")
                    print(f"\tMax error: {np.max(errors):.2e}")

    # Save results to CSV.
    df = pd.DataFrame(stats, columns=[
        'batch_size', 'n', 'dtype', 'atol', 'max_sweeps',
        'mean_time', 'std_time', 'median_time',
        'mean_error', 'max_error'
    ])
    df.to_csv('batch_eigendecomp_stats.csv', index=False)
    # Plot mean time as a function of the solver knobs:
    # x = log10(atol), y = log2(max_sweeps), color = log10(mean_time).
    import matplotlib.pyplot as plt
    plt.figure(figsize=(12, 8))
    # One scatter series per (batch_size, n) combination.
    unique_batch_sizes = df['batch_size'].unique()
    unique_ns = df['n'].unique()
    for bs in unique_batch_sizes:
        for n in unique_ns:
            subset = df[(df['batch_size'] == bs) & (df['n'] == n)]
            if len(subset) == 0:
                continue
            # Color each point by its mean runtime.
            plt.scatter(
                np.log10(subset['atol']),
                np.log2(subset['max_sweeps']),
                c=np.log10(subset['mean_time']),
                cmap='viridis',
                s=100,
                label=f'bs={bs}, n={n}'
            )
    plt.colorbar(label='log10(Mean Time (s))')
    plt.xlabel('log10(atol)')
    plt.ylabel('log2(max_sweeps)')
    plt.title('Eigendecomposition Performance')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig('eigendecomp_performance.png', dpi=300, bbox_inches='tight')
    plt.close()
    # Plot the error/time trade-off:
    # x = log10(mean_error), y = log10(mean_time), color = log2(max_sweeps).
    plt.figure(figsize=(12, 8))
    for bs in unique_batch_sizes:
        for n in unique_ns:
            subset = df[(df['batch_size'] == bs) & (df['n'] == n)]
            if len(subset) == 0:
                continue
            plt.scatter(
                np.log10(subset['mean_error']),
                np.log10(subset['mean_time']),
                c=np.log2(subset['max_sweeps']),
                cmap='viridis',
                s=100,
                label=f'bs={bs}, n={n}'
            )
    plt.colorbar(label='log2(max_sweeps)')
    plt.xlabel('log10(Mean Error)')
    plt.ylabel('log10(Mean Time (s))')
    plt.title('Error vs Time Trade-off')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig('error_vs_time.png', dpi=300, bbox_inches='tight')
    plt.close()