- 
      
- 
        Save rmccorm4/dabccb1f31dbdcf1019a4df431067e52 to your computer and use it in GitHub Desktop. 
| #!/usr/bin/env python3 | |
| import argparse | |
| from typing import Tuple, List | |
| import numpy as np | |
| import pycuda.driver as cuda | |
| import pycuda.autoinit | |
| import tensorrt as trt | |
# Module-wide TensorRT logger, handed to trt.Runtime in load_engine().
# Constructed with WARNING severity (presumably the minimum level reported).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def is_fixed(shape: Tuple[int, ...]) -> bool:
    """Return True if every dimension of *shape* is known.

    This is the negation of ``is_dynamic``: no dimension is ``None`` or
    negative.
    """
    return not is_dynamic(shape)
def is_dynamic(shape: Tuple[int, ...]) -> bool:
    """Return True if any dimension of *shape* is unknown.

    A dimension counts as unknown when it is ``None`` or negative. TensorRT
    reports not-yet-specified dynamic dimensions as -1.

    Args:
        shape: Any iterable of dimensions (e.g. a tuple or a ``trt.Dims``).
    """
    return any(dim is None or dim < 0 for dim in shape)
def setup_binding_shapes(
    engine: trt.ICudaEngine,
    context: trt.IExecutionContext,
    host_inputs: List[np.ndarray],
    input_binding_idxs: List[int],
    output_binding_idxs: List[int],
):
    """Bind the input shapes on *context* and allocate buffers for its outputs.

    Call this every time the input shapes change. Once every input binding
    shape has been set, TensorRT computes the output binding shapes. Those
    shapes are then used to size the host and device output buffers.

    Args:
        engine: Engine the context was created from. Used to look up each
            output binding's data type.
        context: Execution context whose binding shapes are set.
        host_inputs: One host array per entry of ``input_binding_idxs``.
            Each array's ``shape`` is bound to the matching input.
        input_binding_idxs: Input binding indices, in the same order as
            ``host_inputs``.
        output_binding_idxs: Output binding indices to allocate buffers for.

    Returns:
        ``(host_outputs, device_outputs)``: two lists ordered like
        ``output_binding_idxs``. They hold an uninitialised host
        ``np.ndarray`` and a device allocation of the same byte size for
        each output.

    Raises:
        ValueError: If ``host_inputs`` and ``input_binding_idxs`` differ in
            length.
        RuntimeError: If TensorRT rejects an input shape, or if a binding
            shape is still unspecified after all inputs were set.
    """
    # zip() would silently drop the extras, leaving some inputs unbound.
    if len(host_inputs) != len(input_binding_idxs):
        raise ValueError(
            "Expected {} host inputs, got {}".format(
                len(input_binding_idxs), len(host_inputs)
            )
        )

    # Explicitly set the dynamic input shapes, so the dynamic output
    # shapes can be computed internally
    for host_input, binding_index in zip(host_inputs, input_binding_idxs):
        if not context.set_binding_shape(binding_index, host_input.shape):
            raise RuntimeError(
                "TensorRT rejected shape {} for input binding {}".format(
                    host_input.shape, binding_index
                )
            )
    # Use an explicit check rather than `assert`, which is stripped under
    # `python -O`.
    if not context.all_binding_shapes_specified:
        raise RuntimeError("Not all binding shapes are specified for this context")

    host_outputs = []
    device_outputs = []
    for binding_index in output_binding_idxs:
        output_shape = tuple(context.get_binding_shape(binding_index))
        # Match the engine's declared output type instead of assuming
        # float32. Otherwise int/half/bool outputs would be mis-sized on
        # the device and misread on the host.
        dtype = trt.nptype(engine.get_binding_dtype(binding_index))
        # Allocate buffers to hold output results after copying back to host
        host_buffer = np.empty(output_shape, dtype=dtype)
        host_outputs.append(host_buffer)
        # Allocate output buffers on device
        device_outputs.append(cuda.mem_alloc(host_buffer.nbytes))
    return host_outputs, device_outputs
def get_binding_idxs(engine: trt.ICudaEngine, profile_index: int):
    """Split one optimization profile's bindings into inputs and outputs.

    The engine's bindings are divided evenly among its optimization
    profiles. Profile ``k`` owns the contiguous index range
    ``[k * n, (k + 1) * n)``, where
    ``n = engine.num_bindings // engine.num_optimization_profiles``.

    Returns:
        ``(input_binding_idxs, output_binding_idxs)``, each in ascending
        binding-index order.
    """
    per_profile = engine.num_bindings // engine.num_optimization_profiles
    first = profile_index * per_profile

    inputs: List[int] = []
    outputs: List[int] = []
    for idx in range(first, first + per_profile):
        # Route each binding into the list matching its direction.
        target = inputs if engine.binding_is_input(idx) else outputs
        target.append(idx)
    return inputs, outputs
def load_engine(filename: str):
    """Deserialize a TensorRT engine from the file at *filename*.

    Raises:
        OSError: If the file cannot be read.
        RuntimeError: If TensorRT fails to deserialize the data, i.e.
            ``deserialize_cuda_engine`` returns None. The reason is usually
            reported through TRT_LOGGER.
    """
    # Load serialized engine file into memory. The file is closed before
    # deserialization starts.
    with open(filename, "rb") as f:
        serialized_engine = f.read()
    with trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(serialized_engine)
    # deserialize_cuda_engine signals failure by returning None. Fail here
    # with a clear message instead of with an AttributeError at first use.
    if engine is None:
        raise RuntimeError(
            "Failed to deserialize TensorRT engine from {!r}".format(filename)
        )
    return engine
def get_random_inputs(
    engine: trt.ICudaEngine,
    context: trt.IExecutionContext,
    input_binding_idxs: List[int],
):
    """Create one random host array per input binding.

    Fixed input shapes are used as-is. For dynamic inputs, the "opt" shape
    (index 1 of ``engine.get_profile_shape``) of the context's active
    optimization profile is used.

    Each array has the binding's own dtype. Values come from
    ``np.random.random`` (in ``[0, 1)``) cast to that dtype, so integer
    inputs end up all zeros.

    Returns:
        A list of ``np.ndarray`` ordered like ``input_binding_idxs``.
    """
    # Input data for inference
    host_inputs = []
    for binding_index in input_binding_idxs:
        # If input shape is fixed, we'll just use it
        input_shape = context.get_binding_shape(binding_index)
        # If input shape is dynamic, we'll arbitrarily select one of the
        # min/opt/max shapes from our optimization profile
        if is_dynamic(input_shape):
            profile_index = context.active_optimization_profile
            profile_shapes = engine.get_profile_shape(profile_index, binding_index)
            # 0=min, 1=opt, 2=max, or choose any shape, (min <= shape <= max)
            input_shape = profile_shapes[1]
        # Use the engine's declared input type instead of assuming float32,
        # so the host buffer's byte size matches what the binding expects.
        dtype = trt.nptype(engine.get_binding_dtype(binding_index))
        host_inputs.append(np.random.random(tuple(input_shape)).astype(dtype))
    return host_inputs
def main():
    """Run one inference pass on a serialized engine using random inputs.

    The engine path comes from ``-e/--engine``. Input and output shapes are
    printed, followed by the raw output arrays.

    Raises:
        RuntimeError: If ``context.execute_v2`` reports failure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e", "--engine", required=True, type=str, help="Path to TensorRT engine file."
    )
    args = parser.parse_args()

    # Load a serialized engine into memory
    engine = load_engine(args.engine)
    # Create context, this can be re-used
    context = engine.create_execution_context()
    # Profile 0 (first profile) is used by default
    context.active_optimization_profile = 0

    # These binding_idxs can change if either the context or the
    # active_optimization_profile are changed
    input_binding_idxs, output_binding_idxs = get_binding_idxs(
        engine, context.active_optimization_profile
    )

    # Generate random inputs based on profile shapes
    host_inputs = get_random_inputs(engine, context, input_binding_idxs)
    print("Input Shapes: {}".format([inp.shape for inp in host_inputs]))

    # Allocate device memory for inputs. This can be easily re-used if the
    # input shapes don't change
    device_inputs = [cuda.mem_alloc(h_input.nbytes) for h_input in host_inputs]
    # Copy host inputs to device, this needs to be done for each new input
    for h_input, d_input in zip(host_inputs, device_inputs):
        cuda.memcpy_htod(d_input, h_input)

    # This needs to be called everytime your input shapes change
    # If your inputs are always the same shape (same batch size, etc.),
    # then you will only need to call this once
    host_outputs, device_outputs = setup_binding_shapes(
        engine, context, host_inputs, input_binding_idxs, output_binding_idxs,
    )
    print("Output Shapes: {}".format([out.shape for out in host_outputs]))

    # Bindings are a list of device pointers indexed by binding index.
    # Simply concatenating inputs + outputs assumes every input precedes
    # every output, and that the active profile's bindings start at index
    # 0. Neither holds in general (e.g. with several optimization profiles).
    # So place each pointer at its own binding index and leave the slots of
    # other profiles as 0 (null).
    bindings = [0] * engine.num_bindings
    for binding_index, device_buffer in zip(input_binding_idxs, device_inputs):
        bindings[binding_index] = int(device_buffer)
    for binding_index, device_buffer in zip(output_binding_idxs, device_outputs):
        bindings[binding_index] = int(device_buffer)

    # Inference. execute_v2 returns False on failure, so check it instead of
    # printing stale, uninitialised output buffers.
    if not context.execute_v2(bindings):
        raise RuntimeError("TensorRT inference (execute_v2) failed")

    # Copy outputs back to host to view results
    for h_output, d_output in zip(host_outputs, device_outputs):
        cuda.memcpy_dtoh(h_output, d_output)

    # View outputs
    print(host_outputs)

    # Cleanup (Can also use context managers instead)
    del context
    del engine


if __name__ == "__main__":
    main()
Hi @TerryBryant,
Please refer to this code block: https://gist.github.com/rmccorm4/dabccb1f31dbdcf1019a4df431067e52#file-dynamic_shape_inference-py-L28-L33
When all input binding shapes for an execution context have been specified (context.all_binding_shapes_specified==True), TensorRT should calculate the output binding shapes for that context automatically under the hood.
You can verify this by checking the context output binding shapes before and after setting the input binding shapes.
Hi @TerryBryant,
Please refer to this code block: https://gist.github.com/rmccorm4/dabccb1f31dbdcf1019a4df431067e52#file-dynamic_shape_inference-py-L28-L33
When all input binding shapes for an execution context have been specified (context.all_binding_shapes_specified==True), TensorRT should calculate the output binding shapes for that context automatically under the hood.
You can verify this by checking the context output binding shapes before and after setting the input binding shapes.
I see, thanks a lot !
Hi, I'm trying to run this script in multiple threads. I want to load the engine once and create a separate context for each thread. Because different threads have different input sizes, the binding shapes keep changing. But when I write it this way, the following errors occur:
[TensorRT] ERROR: Profile 0 has been chosen by another IExecutionContext. Use another profileIndex or destroy the IExecutionContext that use this profile.
[TensorRT] WARNING: Could not set default profile 0 for execution context. Profile index must be set explicitly.
[TensorRT] ERROR: Profile 0 has been chosen by another IExecutionContext. Use another profileIndex or destroy the IExecutionContext that use this profile.
I searched for the problem and only found some C++ docs, which say:
If the associated CUDA engine has dynamic inputs, this method must be called at least once with a unique profileIndex before calling execute or enqueue (i.e. the profile index may not be in use by another execution context that has not been destroyed yet). For the first execution context that is created for an engine, setOptimizationProfile(0) is called implicitly.
But I still don't know how to write the multi-threaded script. Could you help me? Thanks in advance!
Hi @TerryBryant,
[TensorRT] ERROR: Profile 0 has been chosen by another IExecutionContext. Use another profileIndex or destroy the IExecutionContext that use this profile.
You're likely using the same profile index on each thread, which isn't currently allowed.
At engine building time (single thread), you'll need to create at least as many optimization profiles as the number of threads you expect to be running simultaneously.
At runtime, for each thread you will likely need to do something like:
- Create a new execution context
- Assign the execution context an optimization profile that's not currently in use by another thread
For example, let's say you want to run 4 threads.
# --- Build time --- #
# Create 4 opt profiles ...
profile0 = builder.create_optimization_profile()
profile0.set_shape(...)
builder_config.add_optimization_profile(profile0) # profile_index=0
profile1 = builder.create_optimization_profile()
profile1.set_shape(...)
builder_config.add_optimization_profile(profile1) # profile_index=1
profile2 = builder.create_optimization_profile()
profile2.set_shape(...)
builder_config.add_optimization_profile(profile2) # profile_index=2
profile3 = builder.create_optimization_profile()
profile3.set_shape(...)
builder_config.add_optimization_profile(profile3) # profile_index=3
...
engine = builder.build_engine(network, builder_config)
# --- Inference time ---
# Create an execution context with a unique optimization profile for each thread
# thread 0
context0 = engine.create_execution_context()
context0.active_optimization_profile = 0
# thread 1
context1 = engine.create_execution_context()
context1.active_optimization_profile = 1
# thread 2
context2 = engine.create_execution_context()
context2.active_optimization_profile = 2
# thread 3
context3 = engine.create_execution_context()
context3.active_optimization_profile = 3
Hi, @rmccorm4 ,
Thank you for your sample code; I think I get it. But I'd still like a better solution. The input sizes of all my data fall in the same range, so I only need one kind of profile, which all the data could share. Also, at inference time it's not convenient to assign a different profile to each thread, because the threads may run in random order.
Waiting for more instructions. Thanks.
Hi @TerryBryant,
You can define several profiles covering the same range of shapes. I agree it would be nice if you could re-use the same profile by multiple threads simultaneously, but I don't believe that's currently possible.
Hi @TerryBryant,
You can define several profiles covering the same range of shapes. I agree it would be nice if you could re-use the same profile, but I don't believe that's currently possible.
Ok, I got it. Thank you.
Hi @TerryBryant,
[TensorRT] ERROR: Profile 0 has been chosen by another IExecutionContext. Use another profileIndex or destroy the IExecutionContext that use this profile.
You're likely using the same profile index on each thread, which isn't currently allowed.
At engine building time (single thread), you'll need to create at least as many optimization profiles as the number of threads you expect to be running simultaneously.
At runtime, for each thread you will likely need to do something like:
- Create a new execution context
- Assign the execution context an optimization profile that's not currently in use by another thread
For example, let's say you want to run 4 threads.
# --- Build time --- #
# Create 4 opt profiles ...
profile0 = builder.create_optimization_profile()
profile0.set_shape(...)
builder_config.add_optimization_profile(profile0) # profile_index=0
profile1 = builder.create_optimization_profile()
profile1.set_shape(...)
builder_config.add_optimization_profile(profile1) # profile_index=1
profile2 = builder.create_optimization_profile()
profile2.set_shape(...)
builder_config.add_optimization_profile(profile2) # profile_index=2
profile3 = builder.create_optimization_profile()
profile3.set_shape(...)
builder_config.add_optimization_profile(profile3) # profile_index=3
...
engine = builder.build_engine(network, builder_config)
# --- Inference time ---
# Create an execution context with a unique optimization profile for each thread
# thread 0
context0 = engine.create_execution_context()
context0.active_optimization_profile = 0
# thread 1
context1 = engine.create_execution_context()
context1.active_optimization_profile = 1
# thread 2
context2 = engine.create_execution_context()
context2.active_optimization_profile = 2
# thread 3
context3 = engine.create_execution_context()
context3.active_optimization_profile = 3
I've tried this solution, and two problems occurred:
1. The serialized engine file becomes very large, because I added 10 profiles.
2. The following log appears, but I can still run inference. I don't know whether it's just a warning or whether something is actually going wrong:
[TensorRT] WARNING: Total space of persistent layer space is 524160 on host and 3773253120 on device
[TensorRT] ERROR: ../rtSafe/safeRuntime.cpp (25) - Cuda Error in allocate: 2 (out of memory)
[TensorRT] ERROR: FAILED_ALLOCATION: std::exception
So I think this is only a temporary solution. I hope you and the official TensorRT team can take this multi-threading problem into consideration.
Hi, as I understand it, with dynamic input shapes the output shape is computed from the input shape.
Assume the input shape is something like [1, 3, -1, -1] and the output shape is like [1, 3, -1, -1].
We can set the input shape from the shape of the input data. But it seems the output shape would have to be computed by running a forward pass of the model before the GPU output buffer can be allocated, and I can't see any relevant code for that in your example. Could you give me some advice? Thanks in advance!