FP8GEMM
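A minimal example of FP8 GEMM using Transformer Engine's low-level cpp_extensions API: two FP32 tensors are cast to FP8 (E4M3), multiplied with tex.te_gemm, and the result is compared against torch.matmul on the dequantized inputs.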
import torch
import transformer_engine.pytorch.cpp_extensions as texcpp
from transformer_engine.pytorch.module import get_workspace
import transformer_engine_extensions as tex

scale = 1.0
meta = tex.FP8TensorMeta()
meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale
meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")


def cast_to_fp8(x, qtype):
    # Cast x to FP8 (qtype) and remember the FP8 dtype on the returned tensor.
    ret = texcpp.cast_to_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, qtype)
    ret._fp8_qtype = qtype
    return ret


def cast_from_fp8(x, qtype):
    # Cast FP8 data x back to the higher-precision dtype qtype.
    ret = texcpp.cast_from_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, x._fp8_qtype, qtype)
    ret._fp8_qtype = qtype
    return ret


one_scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")
empty_tensor = torch.Tensor()
workspace = get_workspace()
assert workspace.is_cuda

# Map TE dtypes back to PyTorch dtypes; FP8 data is stored in uint8 tensors.
PT_DType = dict([(v, k) for k, v in texcpp.TE_DType.items()])
PT_DType[tex.DType.kFloat8E4M3] = torch.uint8
PT_DType[tex.DType.kFloat8E5M2] = torch.uint8


def fp8_gemm(fa, fb, trans_a, trans_b, bias=None, qtype=tex.DType.kFloat32):
    '''
    # te_gemm
    input_A: (A_row, A_col)
    input_B: (B_row, B_col)

    when transa, transb = True, False
    m, k, n = A_row, A_col, B_row
    lda, ldb, ldd = A_col, A_col, A_row
    output_D: (B_row, A_row)

    when transa, transb = False, False
    m, k, n = A_col, A_row, B_row
    lda, ldb, ldd = A_col, A_row, A_col
    output_D: (B_row, A_col)

    when transa, transb = False, True
    m, k, n = A_col, A_row, B_col
    lda, ldb, ldd = A_col, B_col, A_col
    output_D: (B_col, A_col)
    '''
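    # Worked example of the first case above (trans_a=True, trans_b=False):
    # fa with shape (A_row, A_col) and fb with shape (B_row, B_col) must satisfy
    # A_col == B_col, and te_gemm writes a (B_row, A_row) output, i.e. (in
    # row-major terms) the product fb @ fa.T.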
    assert fa.is_cuda and fb.is_cuda
    assert fa.is_contiguous()
    assert fb.is_contiguous()
    device = fa.device
    fa_qtype, fb_qtype = fa._fp8_qtype, fb._fp8_qtype
    A_row, A_col = fa.shape
    B_row, B_col = fb.shape
    if trans_a and not trans_b:
        assert A_col == B_col
        C_row, C_col = B_row, A_row
    elif not trans_a and not trans_b:
        assert A_row == B_col
        C_row, C_col = B_row, A_col
    elif not trans_a and trans_b:
        assert A_row == B_row
        C_row, C_col = B_col, A_col
    else:
        raise NotImplementedError("trans_a=True, trans_b=True is not supported")
    out_shape = (C_row, C_col)
    dtype = PT_DType[qtype]
    out = torch.empty(out_shape, dtype=dtype, device=device)
    # te_gemm works in column-major order.
    tex.te_gemm(
        fa, one_scale_inv, fa_qtype, trans_a,
        fb, one_scale_inv, fb_qtype, trans_b,
        out, qtype,
        bias if bias is not None else empty_tensor, empty_tensor, False,
        workspace, workspace.shape[0],
        False, True,
    )
    out._fp8_qtype = qtype
    return out


def fp8_matmul(fa, fb, bias=None, qtype=tex.DType.kFloat32):
    # Compute the row-major product fa @ fb.
    # trans_a = False and trans_b = False is not implemented.
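    # Instead, fb is transposed so that the trans_a=True, trans_b=False case of
    # fp8_gemm (output = input_B @ input_A.T, see the docstring above) yields fa @ fb.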
    fb_qtype = fb._fp8_qtype
    fb = fb.T.contiguous()
    fb._fp8_qtype = fb_qtype
    return fp8_gemm(fb, fa, trans_a=True, trans_b=False, bias=bias, qtype=qtype)


if __name__ == '__main__':
    a = torch.randn(128, 128).cuda()
    b = torch.randn(128, 128).cuda()
    fa = cast_to_fp8(a, tex.DType.kFloat8E4M3)
    fb = cast_to_fp8(b, tex.DType.kFloat8E4M3)
    # Dequantized reference: matmul on the FP32 round-trips of the FP8 inputs.
    qa = cast_from_fp8(fa, tex.DType.kFloat32)
    qb = cast_from_fp8(fb, tex.DType.kFloat32)
    qc = torch.matmul(qa, qb)
    # FP8 GEMM: E4M3/E5M2 @ E4M3/E5M2 with an FP16/FP32 output.
    qc2 = fp8_matmul(fa, fb, qtype=tex.DType.kFloat16)
    print(qc, qc2)
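    # Optional sanity check (not in the original gist): qc and qc2 are computed
    # from the same FP8 inputs, so they should agree up to the FP16 output
    # precision of the FP8 GEMM.
    max_rel_err = (qc2.float() - qc).abs().max() / qc.abs().max()
    print(f"max relative error: {max_rel_err.item():.4f}")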