Created
July 16, 2019 09:55
-
-
Save RomanSteinberg/63ed16aa2e2e4e19c1ad84fdc2b1f551 to your computer and use it in GitHub Desktop.
PyTorch -> TensorRT
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorrt as trt | |
import os | |
import torch | |
import onnx | |
# Module-level TensorRT logger, constructed with the WARNING severity level.
# convert_to_trt passes it to both trt.Builder and trt.OnnxParser.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING) | |
def convert_to_trt(image_width, image_height):
    """Take an ONNX file and create a TensorRT engine to run inference with.

    Reads 'model.onnx' from the current working directory (generating it
    first via convert_to_onnx when it does not exist), parses it into a
    TensorRT network, builds a CUDA engine and writes the serialized engine
    to 'model.trt'.

    Args:
        image_width: Input image width; only forwarded to convert_to_onnx
            when the ONNX file has to be generated.
        image_height: Input image height; only forwarded to convert_to_onnx
            when the ONNX file has to be generated.

    Raises:
        SystemExit: If the ONNX file is still missing after the export
            attempt (non-zero exit status, message on stderr).
        RuntimeError: If the ONNX parser rejects the model or the builder
            fails to produce an engine.
    """
    onnx_file_path = 'model.onnx'
    engine_file_path = 'model.trt'

    if not os.path.exists(onnx_file_path):
        convert_to_onnx(image_width, image_height)

    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network() as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        builder.max_workspace_size = 1 << 30  # 1GB
        builder.max_batch_size = 1

        # Parse model file.
        if not os.path.exists(onnx_file_path):
            # The original called exit(0) here, which reported success to the
            # calling shell. SystemExit with a message keeps the same
            # exception type but exits with a non-zero status.
            raise SystemExit(
                'ONNX file {} not found.'.format(onnx_file_path))

        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parsed_ok = parser.parse(model.read())

        # parse() signals failure through its return value, not an exception;
        # ignoring it would silently build an engine from a partial network.
        if not parsed_ok:
            details = '; '.join(
                str(parser.get_error(index))
                for index in range(parser.num_errors))
            raise RuntimeError(
                'Failed to parse ONNX file {}: {}'.format(
                    onnx_file_path, details or 'no parser errors reported'))
        print('Completed parsing of ONNX file')

        print('Building an engine from file {}; this may take a while...'
              .format(onnx_file_path))
        engine = builder.build_cuda_engine(network)
        # The builder returns None on failure; fail here with a clear message
        # instead of an AttributeError on engine.serialize().
        if engine is None:
            raise RuntimeError(
                'TensorRT failed to build an engine from {}'.format(
                    onnx_file_path))
        print("Completed creating Engine")

        with open(engine_file_path, "wb") as f:
            f.write(engine.serialize())
def convert_to_onnx(image_width, image_height):
    """Export the PyTorch model stored in 'model.pth' to 'model.onnx'.

    Instantiates PytorchModel, loads its weights from 'model.pth', traces it
    on the CPU with a random input of shape (1, 3, image_height, image_width)
    and writes the ONNX graph to 'model.onnx'. The exported file is then
    loaded back, validated with the ONNX checker and printed in
    human-readable form.

    Args:
        image_width: Width of the dummy input used for the export.
        image_height: Height of the dummy input used for the export.
    """
    onnx_file_path = 'model.onnx'

    # NOTE(review): PytorchModel is not defined in the visible part of this
    # file; it must be provided by the surrounding module.
    net = PytorchModel()
    # map_location='cpu' lets a checkpoint saved on a GPU load on a machine
    # without CUDA; the model is exported from the CPU anyway.
    net.load_state_dict(torch.load('model.pth', map_location='cpu'))
    net.eval()
    net.to('cpu')

    dummy_input = torch.randn(1, 3, image_height, image_width)
    torch.onnx.export(net, dummy_input, onnx_file_path, verbose=True)

    onnx_model = onnx.load(onnx_file_path)
    # Check that the IR is well formed.
    onnx.checker.check_model(onnx_model)
    # Print a human readable representation of the graph. printable_graph
    # only returns the string, so it has to be printed explicitly.
    print(onnx.helper.printable_graph(onnx_model.graph))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment