@t-vi
Last active November 12, 2022 12:15
from timeit import default_timer as time
import numpy as np
from numba import cuda
import os
# point Numba at the CUDA libdevice directory and the NVVM library (paths are system-specific)
os.environ['NUMBAPRO_LIBDEVICE'] = '/usr/lib/nvidia-cuda-toolkit/libdevice/'
os.environ['NUMBAPRO_NVVM'] = '/usr/lib/x86_64-linux-gnu/libnvvm.so.3.1.0'
import numpy
import torch
import ctypes
import math
from torch.autograd import Variable
# CUDA kernel: for each (batch bi, row ni) compute
#   v[bi, ni] = sum_mi exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi]
@cuda.jit('(float32[:,:], float32[:,:], float32[:,:], float32[:,:], float32[:,:], int32, int32, int32)')
def cu_exp_matrix_mul(A, c, d, u, v, b, n, m):
    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bx = cuda.blockIdx.x
    by = cuda.blockIdx.y
    bw = cuda.blockDim.x
    bh = cuda.blockDim.y
    bi = tx + bx * bw   # batch index
    ni = ty + by * bh   # row index
    if ni >= n or bi >= b:
        return
    r = 0
    for mi in range(m):
        r += math.exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi]
    v[bi, ni] = r
# CUDA kernel: for each batch bi compute
#   v[bi] = sum_{ni, mi} exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * A[ni, mi]
@cuda.jit('(float32[:,:], float32[:,:], float32[:,:], float32[:], int32, int32, int32)')
def cu_exp_matrix_cost_sum(A, c, d, v, b, n, m):
    tx = cuda.threadIdx.x
    bx = cuda.blockIdx.x
    bw = cuda.blockDim.x
    bi = tx + bx * bw   # batch index
    if bi >= b:
        return
    r = 0
    for mi in range(m):
        for ni in range(n):
            r += math.exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * A[ni, mi]
    v[bi] = r
# wrap a torch.cuda.FloatTensor as a Numba DeviceNDArray that aliases the same GPU memory (no copy)
def get_devicendarray(t):
    assert t.type() == 'torch.cuda.FloatTensor'
    ctx = cuda.cudadrv.driver.driver.get_context()
    mp = cuda.cudadrv.driver.MemoryPointer(ctx, ctypes.c_ulong(t.data_ptr()), t.numel() * 4)
    return cuda.cudadrv.devicearray.DeviceNDArray(t.size(), [i * 4 for i in t.stride()], numpy.dtype('float32'),
                                                  gpu_data=mp, stream=torch.cuda.current_stream().cuda_stream)
# host wrapper: returns v with v[bi, ni] = sum_mi exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * u[bi, mi]
def batch_expmat_product(A, c, d, u):
    BLOCK = 32
    b = c.size(0)
    n = A.size(0)
    m = A.size(1)
    assert A.dim() == 2 and c.dim() == 2 and d.dim() == 2 and u.dim() == 2, "dimension mismatch"
    assert c.size(1) == m and d.size(0) == b and d.size(1) == n and u.size(0) == b and u.size(1) == m, "size mismatch"
    v = u.new(d.size()).zero_()
    Ad, cd, dd, ud, vd = (get_devicendarray(x) for x in (A, c, d, u, v))
    # 2D grid: x covers the b batches, y covers the n output rows
    cu_exp_matrix_mul[((b - 1) // BLOCK + 1, (n - 1) // BLOCK + 1), (BLOCK, BLOCK)](Ad, cd, dd, ud, vd, b, n, m)
    return v
# host wrapper: returns res with res[bi] = sum_{ni, mi} exp(-A[ni, mi] + c[bi, mi] + d[bi, ni]) * A[ni, mi]
def batch_expmat_mat_sum(A, c, d):
    BLOCK = 128
    b = c.size(0)
    n = A.size(0)
    m = A.size(1)
    assert A.dim() == 2 and c.dim() == 2 and d.dim() == 2, "dimension mismatch"
    assert c.size(1) == m and d.size(0) == b and d.size(1) == n, "size mismatch"
    res = c.new(b).zero_()
    Ad, cd, dd, resd = (get_devicendarray(x) for x in (A, c, d, res))
    # 1D grid over the b batches
    cu_exp_matrix_cost_sum[((b - 1) // BLOCK + 1), (BLOCK)](Ad, cd, dd, resd, b, n, m)
    return res
# random test data for a quick smoke test
b, n, m = 100, 200, 300
A = torch.randn(n, m).cuda()
c = torch.randn(b, m).cuda()
d = torch.randn(b, n).cuda()
u = torch.randn(b, m).cuda()
t = torch.randn(b, n).cuda()
w = batch_expmat_product(A, c, d, u)
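# --- quick sanity check (added for illustration, not part of the original gist; assumes the
# tensors defined above and a PyTorch version with broadcasting) ---
# reference with plain broadcast ops: K[bi, ni, mi] = exp(-A[ni, mi] + c[bi, mi] + d[bi, ni])
K = torch.exp(-A.unsqueeze(0) + c.unsqueeze(1) + d.unsqueeze(2))
w_ref = (K * u.unsqueeze(1)).sum(2)            # should match batch_expmat_product(A, c, d, u)
s = batch_expmat_mat_sum(A, c, d)
s_ref = (K * A.unsqueeze(0)).sum(2).sum(1)     # should match batch_expmat_mat_sum(A, c, d)
print((w - w_ref).abs().max(), (s - s_ref).abs().max())   # both should be near zero up to float32 error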
@netw0rkf10w

@t-vi Thank you for sharing your code. Could you please add some comments explaining what each function does? I am trying to write a custom PyTorch function in Numba, and I feel that your code is helpful, but unfortunately it is not easy to understand. Thank you so much!

@grinisrit

To get the context for a tensor t, I think it is better to use:

ctx = cuda.cudadrv.devices.get_context(t.device.index)
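A sketch of how that suggestion could slot into get_devicendarray (hypothetical and untested; everything except the context lookup is unchanged from the gist above):

def get_devicendarray(t):
    assert t.type() == 'torch.cuda.FloatTensor'
    # use the context of the tensor's own device rather than the current driver context
    ctx = cuda.cudadrv.devices.get_context(t.device.index)
    mp = cuda.cudadrv.driver.MemoryPointer(ctx, ctypes.c_ulong(t.data_ptr()), t.numel() * 4)
    return cuda.cudadrv.devicearray.DeviceNDArray(t.size(), [i * 4 for i in t.stride()], numpy.dtype('float32'),
                                                  gpu_data=mp, stream=torch.cuda.current_stream().cuda_stream)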
