@tiandiao123
Created September 7, 2022 22:51
import copy

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
from caffe2.torch.fb.mobile.model_exporter.mobile_model_exporter import (
    export_torch_mobile_model,
    BundledInput,
    MobileModelInfo,
    ModelType,
    OptimizationPassInput,
)
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization import quantize_fx


def transform_model_to_lite_with_fx_quant(model, example_data, save_name, callable_data=None):
    # Post-training static quantization config for the qnnpack (mobile) backend.
    # Note: this flow runs prepare_fx on an eval-mode model, so it needs the
    # PTQ mapping, not the QAT one.
    qconfig_mapping = get_default_qconfig_mapping("qnnpack")
    print(type(qconfig_mapping))
    print(qconfig_mapping)

    # Quantize a deep copy so the caller's float model is left untouched.
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()
    model_prepared = quantize_fx.prepare_fx(
        model_to_quantize, qconfig_mapping, example_inputs=(example_data,)
    )

    # Calibrate the observers before conversion; fall back to the single
    # example input when no calibration iterable is supplied.
    with torch.no_grad():
        for data in callable_data if callable_data is not None else [example_data]:
            model_prepared(data)

    quantized_model = quantize_fx.convert_fx(model_prepared)
    quantized_model.eval()

    # Script the quantized model and run one forward pass as a sanity check.
    torch_script_model = torch.jit.script(quantized_model)
    out = torch_script_model(example_data)
    # print(torch_script_model.graph)

    # Apply mobile-specific graph optimizations, then export the lite model.
    optimized_module = optimize_for_mobile(torch_script_model)
    mobile_model = export_torch_mobile_model(
        optimized_module,
        MobileModelInfo("target_recognition_detection", ModelType.D2Go),
        BundledInput([(example_data,)]),
        OptimizationPassInput(),
        saved_path=save_name,
    )
    return mobile_model
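

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions not in the original gist: torchvision's
# MobileNetV2 stands in for the real model, the input shape is the standard
# 1x3x224x224, the output file name is made up, and the Meta-internal
# exporter imported above must be available on the import path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torchvision.models.mobilenet_v2(pretrained=True)
    example_data = torch.randn(1, 3, 224, 224)
    transform_model_to_lite_with_fx_quant(model, example_data, "mobilenet_v2_lite.ptl")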