Created May 15, 2020 00:25
-
-
Save tcyrus/3db1aee1a34fbd8e368c02264772c1a9 to your computer and use it in GitHub Desktop.
A potential way to create ONNX models for waifu2x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This should be able to convert the pre-trained models from lltcggie/waifu2x-caffe to ONNX.
# Based on the ONNX converter script:
# https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/converter_scripts/caffe_coreml_onnx.ipynb
import os
import sys

import coremltools
import onnxmltools
# Root of the waifu2x-caffe model tree that run() searches by default.
# Relative paths are resolved against the current working directory.
models_dir: str = "waifu2x-caffe/models/"
def convert_model(model_name, model_path):
    """Convert one Caffe model to ONNX, going through an intermediate CoreML model.

    Reads, from ``model_path``:
      * ``<model_name>.prototxt``         -- Caffe prototxt
      * ``<model_name>.json.caffemodel``  -- Caffe weights
    and writes, into the same directory:
      * ``<model_name>.mlmodel``          -- intermediate CoreML model
      * ``<model_name>.onnx``             -- converted ONNX model

    Args:
        model_name: Base file name shared by the model's files.
        model_path: Directory holding the Caffe files; outputs go here too.

    Any exception raised by coremltools or onnxmltools propagates to the caller.

    NOTE(review): ``coremltools.converters.caffe`` only exists in older
    coremltools releases -- confirm the pinned coremltools version.
    """
    def artifact(suffix):
        # Every input and output shares the model's directory and base name.
        return os.path.join(model_path, f"{model_name}{suffix}")

    weights_path = artifact(".json.caffemodel")
    prototxt_path = artifact(".prototxt")
    coreml_path = artifact(".mlmodel")
    onnx_path = artifact(".onnx")

    # Caffe -> CoreML; the converter is handed a (weights, prototxt) pair.
    mlmodel = coremltools.converters.caffe.convert((weights_path, prototxt_path))

    # Keep the CoreML model on disk, then reload it as a spec for the ONNX step.
    mlmodel.save(coreml_path)
    spec = coremltools.utils.load_spec(coreml_path)

    # CoreML -> ONNX, written next to the Caffe files.
    onnxmltools.utils.save_model(onnxmltools.convert_coreml(spec), onnx_path)
def run(root=None, convert=None):
    """Convert every waifu2x-caffe model found under a directory tree.

    Walks ``root`` recursively and calls ``convert(model_name, directory)``
    once for each ``<model_name>.json.caffemodel`` file found. A failure on
    one model is reported on stderr and does not stop the remaining ones.

    Args:
        root: Directory to search. Defaults to the module-level ``models_dir``,
            read at call time.
        convert: Callable taking ``(model_name, model_path)``. Defaults to
            ``convert_model``.

    Returns:
        A list of ``(model_prefix, exception)`` pairs, one per model whose
        conversion raised. ``model_prefix`` is
        ``os.path.join(directory, model_name)``. The list is empty when
        everything converted.
    """
    if root is None:
        root = models_dir
    if convert is None:
        convert = convert_model

    # convert_model() looks for exactly this suffix. Filtering on the full
    # suffix guarantees the length-based strip below yields the right name;
    # a bare ".caffemodel" filter would let mangled names through.
    suffix = ".json.caffemodel"
    failures = []
    for path, _subdirs, files in os.walk(root):
        # Sorted for a deterministic processing order.
        for file_name in sorted(files):
            if not file_name.endswith(suffix):
                continue
            model_name = file_name[:-len(suffix)]
            try:
                convert(model_name, path)
            except Exception as exc:  # best-effort: keep converting the rest
                prefix = os.path.join(path, model_name)
                failures.append((prefix, exc))
                print(f"failed to convert {prefix}: {exc!r}", file=sys.stderr)
    return failures
if __name__ == '__main__':
    # Start a conversion only when executed as a script, never on import.
    run()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.