Skip to content

Instantly share code, notes, and snippets.

@kiyoon
Created December 16, 2023 04:00
Show Gist options
  • Save kiyoon/5c79f4cf79b93c3f84c06ed40ec5eb5f to your computer and use it in GitHub Desktop.
Save kiyoon/5c79f4cf79b93c3f84c06ed40ec5eb5f to your computer and use it in GitHub Desktop.
"""Load a serialized TensorRT engine (a VAE encoder) and benchmark inference.

Reads ``vae_encoder_engine.trt`` from the working directory, binds one CUDA
input/output buffer pair by tensor name, and enqueues 1000 executions on a
dedicated CUDA stream via ``execute_async_v3``.

Requires: the ``tensorrt`` package, a CUDA-capable GPU, and the engine file.
"""
import tensorrt as trt
import torch
import tqdm

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

with open("vae_encoder_engine.trt", "rb") as f:
    serialized_engine = f.read()

# Register built-in plugins so deserialization can resolve any plugin layers
# baked into the engine.
trt.init_libnvinfer_plugins(None, "")
engine = runtime.deserialize_cuda_engine(serialized_engine)
context = engine.create_execution_context()

# Print every I/O tensor's name, dtype, and shape so the hard-coded names and
# shapes below can be checked against the actual engine bindings.
for index in range(engine.num_io_tensors):
    name = engine.get_tensor_name(index)
    print(f"{name} {engine.get_tensor_dtype(name)} {engine.get_tensor_shape(name)}")

input_name = "input"
output_name = "output"
# NOTE(review): these shapes/names are assumed to match the engine's bindings —
# verify against the printout above.
input_shape = (1, 3, 1024, 1024)
output_shape = (1, 8, 128, 128)

# Allocate device memory for input and output, and hand the raw pointers to
# the execution context.
input_buf = torch.randn(input_shape, dtype=torch.float32).cuda()
output_buf = torch.empty(output_shape, dtype=torch.float32).cuda()
context.set_tensor_address(input_name, input_buf.data_ptr())
context.set_tensor_address(output_name, output_buf.data_ptr())

stream = torch.cuda.Stream()
for _ in tqdm.tqdm(range(1000)):
    context.execute_async_v3(stream_handle=stream.cuda_stream)
# Single synchronize after the loop: all 1000 launches are enqueued
# asynchronously, then we wait for the stream to drain before reading results.
stream.synchronize()

print(output_buf.shape)
print(output_buf)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment