Skip to content

Instantly share code, notes, and snippets.

View rmccorm4's full-sized avatar
💭
I may be slow to respond

Ryan McCormick rmccorm4

💭
I may be slow to respond
View GitHub Profile
def setup_binding_shapes(
engine: trt.ICudaEngine,
context: trt.IExecutionContext,
host_inputs: List[np.ndarray],
input_binding_idxs: List[int],
output_binding_idxs: List[int],
):
# Explicitly set the dynamic input shapes, so the dynamic output
# shapes can be computed internally
for host_input, binding_index in zip(host_inputs, input_binding_idxs):
# Load the serialized engine file into memory and deserialize it.
with open("alexnet_dynamic.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
# deserialize_cuda_engine signals failure by returning None rather than
# raising, so fail loudly here instead of with an AttributeError on the
# next line.
if engine is None:
    raise RuntimeError(
        "Failed to deserialize TensorRT engine from alexnet_dynamic.engine"
    )
# Create an execution context; it can be re-used across inferences.
context = engine.create_execution_context()
# Profile 0 (the first profile) is used by default; select it explicitly.
# NOTE(review): newer TensorRT releases deprecate this setter in favor of
# context.set_optimization_profile_async(0, stream) — confirm against the
# targeted TensorRT version.
context.active_optimization_profile = 0
# These binding_idxs can change if either the context or the
# Export the sample AlexNet model to ONNX with a dynamic batch dimension.
# -O always writes to alexnet_onnx.py: without it, re-running wget when the
# file already exists saves the new download as alexnet_onnx.py.1, and the
# python3 step would silently run the stale copy. '&&' skips the export when
# the download fails.
wget -O alexnet_onnx.py \
  https://gist.githubusercontent.com/rmccorm4/b72abac18aed6be4c1725db18eba4930/raw/3919c883b97a231877b454dae695fe074a1acdff/alexnet_onnx.py \
  && python3 alexnet_onnx.py
# Reproduce the implicit-batch "maxBatchSize" behavior with an explicit-batch
# optimization profile: min=(1, *shape), opt=max=(MAX_BATCH_SIZE, *shape).
# Largest batch the profile accepts; because opt == max, it is also the
# batch size the profile is optimized for.
MAX_BATCH_SIZE="32"
# Name of the network input tensor the profile applies to (presumably the
# input exported with the dynamic batch dimension — confirm against the model).
INPUT_NAME='actual_input_1'
# Convert dynamic batch ONNX model to TRT Engine with optimization profile defined
# Build one static (min == opt == max) optimization profile per input
# resolution so that different execution contexts can each use a different
# profile. Profiles are added to the config in order: 224x224 first, then
# 448x448.
def _add_static_profile(shape):
    """Register a profile that pins the "input" tensor to exactly *shape*.

    Creates the profile from the module-level ``builder``, adds it to the
    module-level ``config``, and returns the profile object.
    """
    profile = builder.create_optimization_profile()
    profile.set_shape("input", min=shape, opt=shape, max=shape)
    config.add_optimization_profile(profile)
    return profile


shape0 = (1, 3, 224, 224)
profile0 = _add_static_profile(shape0)
shape1 = (1, 3, 448, 448)
profile1 = _add_static_profile(shape1)
# Precision modes and the INT8 calibrator are configured on the builder
# config rather than through the legacy builder attributes:
#   builder.fp16 = True                            -> BuilderFlag.FP16
#   builder.int8 = True                            -> BuilderFlag.INT8
#   builder.int8_calibrator = MyCustomCalibrator() -> config.int8_calibrator
for precision_flag in (trt.BuilderFlag.FP16, trt.BuilderFlag.INT8):
    config.set_flag(precision_flag)
config.int8_calibrator = MyCustomCalibrator()
# ...
# Previously: builder.build_cuda_engine(network)