Created
February 17, 2020 11:52
-
-
Save AsharFatmi/fea96d663c95ad08c4d8d873b18a0fed to your computer and use it in GitHub Desktop.
convert onnx to trt engine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Copyright 2019 NVIDIA Corporation | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
import sys | |
import glob | |
import math | |
import logging | |
import argparse | |
import tensorrt as trt | |
from ImagenetCalibrator import ImagenetCalibrator, get_calibration_files, get_int8_calibrator # local module | |
# TensorRT logger handed to the builder and the ONNX parser in main();
# its minimum severity is adjusted there from the -v/-vv flags.
TRT_LOGGER = trt.Logger()

# Python-side logging for this script: timestamped records at DEBUG and above.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
def add_profiles(config, inputs, opt_profiles):
    """Log each optimization profile's shapes and add it to the builder config.

    Args:
        config: Builder config; ``add_optimization_profile`` is called on it
            once per profile.
        inputs: Network input tensors; only their ``name`` is read here.
        opt_profiles: Iterable of optimization profiles to register, in order.
    """
    logger.debug("=== Optimization Profiles ===")
    line_template = "{} - OptProfile {} - Min {} Opt {} Max {}"
    for index, opt_profile in enumerate(opt_profiles):
        # Report the min/opt/max shape this profile holds for every input.
        for tensor in inputs:
            shape_min, shape_opt, shape_max = opt_profile.get_shape(tensor.name)
            logger.debug(line_template.format(tensor.name, index, shape_min, shape_opt, shape_max))
        config.add_optimization_profile(opt_profile)
def mark_outputs(network):
    """Mark every output of the network's last layer as a network output.

    Used as a fallback when the parsed network has no outputs marked.
    NOTE: Treating the last layer as the output layer is a heuristic and may
    not be correct in all cases.

    Args:
        network: Network definition to modify in place.
    """
    # An empty network has no "last layer"; bail out instead of asking for
    # layer index -1.
    if not network.num_layers:
        logger.error("Network contains no layers, cannot mark outputs.")
        return

    last_layer = network.get_layer(network.num_layers - 1)
    if not last_layer.num_outputs:
        logger.error("Last layer contains no outputs.")
        return

    for i in range(last_layer.num_outputs):
        network.mark_output(last_layer.get_output(i))
def check_network(network):
    """Ensure the network has outputs and log a description of its inputs/outputs.

    If no outputs are marked, the last layer's outputs are marked as network
    outputs (with a warning). The input and output tensors are then logged at
    DEBUG level with their names padded to a common width.

    Args:
        network: Parsed network definition; may be modified by marking outputs.
    """
    if not network.num_outputs:
        logger.warning("No output nodes found, marking last layer's outputs as network outputs. Correct this if wrong.")
        mark_outputs(network)

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    # default=0 keeps this from raising when the network has neither inputs
    # nor outputs (e.g. marking outputs failed on an empty network).
    max_len = max((len(tensor.name) for tensor in inputs + outputs), default=0)

    logger.debug("=== Network Description ===")
    for i, inp in enumerate(inputs):
        logger.debug("Input {0} | Name: {1:{2}} | Shape: {3}".format(i, inp.name, max_len, inp.shape))
    for i, out in enumerate(outputs):
        logger.debug("Output {0} | Name: {1:{2}} | Shape: {3}".format(i, out.name, max_len, out.shape))
def get_batch_sizes(max_batch_size):
    """Yield powers of 2 up to ``max_batch_size``, then ``max_batch_size`` itself.

    The sequence is 1, 2, 4, ... for every power of 2 not exceeding
    ``max_batch_size``; if ``max_batch_size`` is not itself a power of 2 it is
    yielded last.

    Args:
        max_batch_size: Largest batch size to produce; must be >= 1.

    Yields:
        Batch sizes in increasing order.

    Raises:
        ValueError: If ``max_batch_size`` is less than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be >= 1, got {}".format(max_batch_size))

    # Exact doubling instead of float log2, which can round up for large
    # values and yield a power of 2 above the requested maximum.
    batch_size = 1
    while batch_size * 2 <= max_batch_size:
        yield batch_size
        batch_size *= 2
    yield batch_size

    if max_batch_size != batch_size:
        yield max_batch_size
# TODO: This only covers dynamic shape for batch size, not dynamic shape for other dimensions
def create_optimization_profiles(builder, inputs, batch_sizes=None):
    """Create optimization profiles for the inputs of an explicit-batch network.

    If every input has a fixed batch dimension (``shape[0] > -1``), a single
    profile pinning each input to its own shape is returned. Otherwise one
    profile is created per distinct entry of ``batch_sizes``; in each profile,
    inputs with a dynamic batch dimension use that batch size and inputs with
    a fixed batch dimension keep their own. min, opt and max are always set to
    the same shape.

    Args:
        builder: Object providing ``create_optimization_profile()``.
        inputs: Network input tensors exposing ``name`` and ``shape``.
        batch_sizes: Batch sizes to create profiles for when any input is
            dynamic. Defaults to ``[1, 8, 16, 32, 64]``.

    Returns:
        List of optimization profiles, in creation order.
    """
    if batch_sizes is None:
        batch_sizes = [1, 8, 16, 32, 64]

    # All inputs fixed: one profile is enough and avoids duplicates.
    if all(inp.shape[0] > -1 for inp in inputs):
        profile = builder.create_optimization_profile()
        for inp in inputs:
            fixed_shape = (inp.shape[0], *inp.shape[1:])
            profile.set_shape(inp.name, min=fixed_shape, opt=fixed_shape, max=fixed_shape)
        return [profile]

    # Mixed fixed + dynamic inputs: one profile per requested batch size.
    profiles = {}
    for bs in batch_sizes:
        if bs not in profiles:
            profiles[bs] = builder.create_optimization_profile()
        profile = profiles[bs]

        for inp in inputs:
            # Use a per-input batch dimension rather than rebinding ``bs``:
            # a fixed input must not change which profile is being filled,
            # nor the batch size used for the dynamic inputs after it.
            batch_dim = inp.shape[0] if inp.shape[0] > -1 else bs
            full_shape = (batch_dim, *inp.shape[1:])
            profile.set_shape(inp.name, min=full_shape, opt=full_shape, max=full_shape)

    return list(profiles.values())
def main():
    """Parse command-line options, build a TensorRT engine from an ONNX file, and save it.

    Steps: parse arguments, set the TensorRT logger severity, translate the
    network/builder flags, parse the ONNX model, add optimization profiles
    (explicit batch) or set the max batch size (implicit batch), build the
    engine and write its serialized form to ``--output``.

    Exits with status 1 if the ONNX file cannot be parsed or the engine
    cannot be built.
    """
    arg_parser = argparse.ArgumentParser(description="Creates a TensorRT engine from the provided ONNX file.\n")
    arg_parser.add_argument("--onnx", required=True, help="The ONNX model file to convert to TensorRT")
    arg_parser.add_argument("-o", "--output", type=str, default="model.engine", help="The path at which to write the engine")
    arg_parser.add_argument("-b", "--max-batch-size", type=int, default=32, help="The max batch size for the TensorRT engine input")
    arg_parser.add_argument("-v", "--verbosity", action="count", help="Verbosity for logging. (None) for ERROR, (-v) for INFO/WARNING/ERROR, (-vv) for VERBOSE.")
    arg_parser.add_argument("--explicit-batch", action='store_true', help="Set trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH.")
    arg_parser.add_argument("--explicit-precision", action='store_true', help="Set trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION.")
    arg_parser.add_argument("--gpu-fallback", action='store_true', help="Set trt.BuilderFlag.GPU_FALLBACK.")
    arg_parser.add_argument("--refittable", action='store_true', help="Set trt.BuilderFlag.REFIT.")
    arg_parser.add_argument("--debug", action='store_true', help="Set trt.BuilderFlag.DEBUG.")
    arg_parser.add_argument("--strict-types", action='store_true', help="Set trt.BuilderFlag.STRICT_TYPES.")
    arg_parser.add_argument("--fp16", action="store_true", help="Attempt to use FP16 kernels when possible.")
    arg_parser.add_argument("--int8", action="store_true",
                            help="Attempt to use INT8 kernels when possible. This should generally be used in addition to the --fp16 flag. "
                                 "ONLY SUPPORTS RESNET-LIKE MODELS SUCH AS RESNET50/VGG16/INCEPTION/etc.")
    arg_parser.add_argument("--calibration-cache", help="(INT8 ONLY) The path to read/write from calibration cache.", default="calibration.cache")
    arg_parser.add_argument("--calibration-data", help="(INT8 ONLY) The directory containing {*.jpg, *.jpeg, *.png} files to use for calibration. (ex: Imagenet Validation Set)", default=None)
    arg_parser.add_argument("--calibration-batch-size", help="(INT8 ONLY) The batch size to use during calibration.", type=int, default=32)
    arg_parser.add_argument("--max-calibration-size", help="(INT8 ONLY) The max number of data to calibrate on from --calibration-data.", type=int, default=512)
    arg_parser.add_argument("-p", "--preprocess_func", type=str, default=None, help="(INT8 ONLY) Function defined in 'processing.py' to use for pre-processing calibration data.")
    # Unknown arguments are ignored rather than rejected.
    args, _ = arg_parser.parse_known_args()

    # Adjust logging verbosity
    if args.verbosity is None:
        TRT_LOGGER.min_severity = trt.Logger.Severity.ERROR
    # -v
    elif args.verbosity == 1:
        TRT_LOGGER.min_severity = trt.Logger.Severity.INFO
    # -vv
    else:
        TRT_LOGGER.min_severity = trt.Logger.Severity.VERBOSE
    logger.info("TRT_LOGGER Verbosity: {:}".format(TRT_LOGGER.min_severity))

    # Network flags: a bitmask with one bit per NetworkDefinitionCreationFlag
    network_flags = 0
    if args.explicit_batch:
        network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    if args.explicit_precision:
        network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)

    # Maps argparse attribute names to the builder flag each one enables
    builder_flag_map = {
        'gpu_fallback': trt.BuilderFlag.GPU_FALLBACK,
        'refittable': trt.BuilderFlag.REFIT,
        'debug': trt.BuilderFlag.DEBUG,
        'strict_types': trt.BuilderFlag.STRICT_TYPES,
        'fp16': trt.BuilderFlag.FP16,
        'int8': trt.BuilderFlag.INT8,
    }

    # Building engine
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(network_flags) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, TRT_LOGGER) as onnx_parser:

        config.max_workspace_size = 2**30  # 1GiB

        # Set Builder Config Flags
        for arg_name, builder_flag in builder_flag_map.items():
            if getattr(args, arg_name):
                logger.info("Setting {}".format(builder_flag))
                config.set_flag(builder_flag)

        # Only warn: the flags stay set even without fast hardware support.
        if args.fp16 and not builder.platform_has_fast_fp16:
            logger.warning("FP16 not supported on this platform.")

        if args.int8 and not builder.platform_has_fast_int8:
            logger.warning("INT8 not supported on this platform.")

        if args.int8:
            config.int8_calibrator = get_int8_calibrator(args.calibration_cache,
                                                         args.calibration_data,
                                                         args.max_calibration_size,
                                                         args.preprocess_func,
                                                         args.calibration_batch_size)

        # Fill network attributes with information by parsing model
        with open(args.onnx, "rb") as f:
            if not onnx_parser.parse(f.read()):
                print('ERROR: Failed to parse the ONNX file: {}'.format(args.onnx))
                for error in range(onnx_parser.num_errors):
                    print(onnx_parser.get_error(error))
                sys.exit(1)

        # Display network info and check certain properties
        check_network(network)

        if args.explicit_batch:
            # Add optimization profiles
            # NOTE(review): these batch sizes are fixed and --max-batch-size is
            # not consulted in explicit-batch mode; confirm this is intended.
            batch_sizes = [1, 8, 16, 32, 64]
            inputs = [network.get_input(i) for i in range(network.num_inputs)]
            opt_profiles = create_optimization_profiles(builder, inputs, batch_sizes)
            add_profiles(config, inputs, opt_profiles)
        # Implicit Batch Network
        else:
            builder.max_batch_size = args.max_batch_size

        logger.info("Building Engine...")
        engine = builder.build_engine(network, config)
        # A failed build yields no engine; report it instead of crashing in
        # the ``with`` statement below and before touching the output file.
        if engine is None:
            logger.error("Failed to build the TensorRT engine.")
            sys.exit(1)

        with engine, open(args.output, "wb") as f:
            logger.info("Serializing engine to file: {:}".format(args.output))
            f.write(engine.serialize())
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment